diff --git a/src/instamatic/TEMController/TEMController.py b/src/instamatic/TEMController/TEMController.py index 94febe95..d240ef4e 100644 --- a/src/instamatic/TEMController/TEMController.py +++ b/src/instamatic/TEMController/TEMController.py @@ -1,15 +1,17 @@ import time from collections import namedtuple from concurrent.futures import ThreadPoolExecutor -from typing import Tuple +from typing import Optional, Tuple import numpy as np from instamatic import config from instamatic.camera import Camera +from instamatic.camera.camera_base import CameraBase from instamatic.exceptions import TEMControllerError from instamatic.formats import write_tiff from instamatic.image_utils import rotate_image +from instamatic.TEMController.microscope_base import MicroscopeBase from .deflectors import * from .lenses import * @@ -89,7 +91,7 @@ class TEMController: cam: Camera control object (see instamatic.camera) [optional] """ - def __init__(self, tem, cam=None): + def __init__(self, tem: MicroscopeBase, cam: Optional[CameraBase] = None): super().__init__() self._executor = ThreadPoolExecutor(max_workers=1) @@ -122,7 +124,7 @@ def __init__(self, tem, cam=None): def __repr__(self): return (f'Mode: {self.tem.getFunctionMode()}\n' - f'High tension: {self.high_tension/1000:.0f} kV\n' + f'High tension: {self.high_tension / 1000:.0f} kV\n' f'Current density: {self.current_density:.2f} pA/cm2\n' f'{self.gunshift}\n' f'{self.guntilt}\n' @@ -244,7 +246,7 @@ def run_script(self, script: str, verbose: bool = True) -> None: t1 = time.perf_counter() if verbose: - print(f'\nScript finished in {t1-t0:.4f} s') + print(f'\nScript finished in {t1 - t0:.4f} s') def get_stagematrix(self, binning: int = None, mag: int = None, mode: int = None): """Helper function to get the stage matrix from the config file. The diff --git a/src/instamatic/TEMController/fei_microscope.py b/src/instamatic/TEMController/fei_microscope.py index 7afa5e42..81c55031 100644 --- a/src/instamatic/TEMController/fei_microscope.py +++ b/src/instamatic/TEMController/fei_microscope.py @@ -7,6 +7,7 @@ from instamatic import config from instamatic.exceptions import FEIValueError, TEMCommunicationError +from instamatic.TEMController.microscope_base import MicroscopeBase logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ def get_camera_length_mapping(): CameraLengthMapping = get_camera_length_mapping() -class FEIMicroscope: +class FEIMicroscope(MicroscopeBase): """Python bindings to the FEI microscope using the COM interface.""" def __init__(self, name='fei'): diff --git a/src/instamatic/TEMController/fei_simu_microscope.py b/src/instamatic/TEMController/fei_simu_microscope.py index 41db8258..3c62f462 100644 --- a/src/instamatic/TEMController/fei_simu_microscope.py +++ b/src/instamatic/TEMController/fei_simu_microscope.py @@ -7,6 +7,7 @@ from instamatic import config from instamatic.exceptions import FEIValueError, TEMCommunicationError +from instamatic.TEMController.microscope_base import MicroscopeBase logger = logging.getLogger(__name__) @@ -17,7 +18,7 @@ MAX = 1.0 -class FEISimuMicroscope: +class FEISimuMicroscope(MicroscopeBase): """Python bindings to the FEI simulated microscope using the COM interface.""" diff --git a/src/instamatic/TEMController/jeol_microscope.py b/src/instamatic/TEMController/jeol_microscope.py index 49fbe4fc..346e2840 100644 --- a/src/instamatic/TEMController/jeol_microscope.py +++ b/src/instamatic/TEMController/jeol_microscope.py @@ -7,6 +7,7 @@ from instamatic import config from instamatic.exceptions import JEOLValueError, TEMCommunicationError, TEMValueError +from instamatic.TEMController.microscope_base import MicroscopeBase logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ # Piezo stage seems to operate on a different level than standard XY -class JeolMicroscope: +class JeolMicroscope(MicroscopeBase): """Python bindings to the JEOL microscope using the COM interface.""" def __init__(self, name: str = 'jeol'): diff --git a/src/instamatic/TEMController/microscope.py b/src/instamatic/TEMController/microscope.py index d52b4a2a..ec16e2f2 100644 --- a/src/instamatic/TEMController/microscope.py +++ b/src/instamatic/TEMController/microscope.py @@ -1,11 +1,12 @@ from instamatic import config +from instamatic.TEMController.microscope_base import MicroscopeBase default_tem_interface = config.microscope.interface __all__ = ['Microscope', 'get_tem'] -def get_tem(interface: str): +def get_tem(interface: str) -> 'type[MicroscopeBase]': """Grab tem class with the specific 'interface'.""" simulate = config.settings.simulate @@ -28,7 +29,7 @@ def get_tem(interface: str): return cls -def Microscope(name: str = None, use_server: bool = False): +def Microscope(name: str = None, use_server: bool = False) -> MicroscopeBase: """Generic class to load microscope interface class. name: str diff --git a/src/instamatic/TEMController/microscope_base.py b/src/instamatic/TEMController/microscope_base.py new file mode 100644 index 00000000..52865d9e --- /dev/null +++ b/src/instamatic/TEMController/microscope_base.py @@ -0,0 +1,176 @@ +from abc import ABC, abstractmethod +from typing import Tuple + + +class MicroscopeBase(ABC): + @abstractmethod + def getBeamShift(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getBeamTilt(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getBrightness(self) -> int: + pass + + @abstractmethod + def getCondensorLensStigmator(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getCurrentDensity(self) -> float: + pass + + @abstractmethod + def getDiffFocus(self, confirm_mode: bool) -> int: + pass + + @abstractmethod + def getDiffShift(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getFunctionMode(self) -> str: + pass + + @abstractmethod + def getGunShift(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getGunTilt(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getHTValue(self) -> float: + pass + + @abstractmethod + def getImageShift1(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getImageShift2(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getIntermediateLensStigmator(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getMagnification(self) -> int: + pass + + @abstractmethod + def getMagnificationAbsoluteIndex(self) -> int: + pass + + @abstractmethod + def getMagnificationIndex(self) -> int: + pass + + @abstractmethod + def getMagnificationRanges(self) -> dict: + pass + + @abstractmethod + def getObjectiveLensStigmator(self) -> Tuple[int, int]: + pass + + @abstractmethod + def getSpotSize(self) -> int: + pass + + @abstractmethod + def getStagePosition(self) -> Tuple[int, int, int, int, int]: + pass + + @abstractmethod + def isBeamBlanked(self) -> bool: + pass + + @abstractmethod + def isStageMoving(self) -> bool: + pass + + @abstractmethod + def release_connection(self) -> None: + pass + + @abstractmethod + def setBeamBlank(self, mode: bool) -> None: + pass + + @abstractmethod + def setBeamShift(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setBeamTilt(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setBrightness(self, value: int) -> None: + pass + + @abstractmethod + def setCondensorLensStigmator(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setDiffFocus(self, value: int, confirm_mode: bool) -> None: + pass + + @abstractmethod + def setDiffShift(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setFunctionMode(self, value: int) -> None: + pass + + @abstractmethod + def setGunShift(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setGunTilt(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setImageShift1(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setImageShift2(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setIntermediateLensStigmator(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setMagnification(self, value: int) -> None: + pass + + @abstractmethod + def setMagnificationIndex(self, index: int) -> None: + pass + + @abstractmethod + def setObjectiveLensStigmator(self, x: int, y: int) -> None: + pass + + @abstractmethod + def setSpotSize(self, value: int) -> None: + pass + + @abstractmethod + def setStagePosition(self, x: int, y: int, z: int, a: int, b: int, wait: bool) -> None: + pass + + @abstractmethod + def stopStage(self) -> None: + pass diff --git a/src/instamatic/TEMController/simu_microscope.py b/src/instamatic/TEMController/simu_microscope.py index 6d7a4524..7492464f 100644 --- a/src/instamatic/TEMController/simu_microscope.py +++ b/src/instamatic/TEMController/simu_microscope.py @@ -4,6 +4,7 @@ from instamatic import config from instamatic.exceptions import TEMValueError +from instamatic.TEMController.microscope_base import MicroscopeBase NTRLMAPPING = { 'GUN1': 0, @@ -30,7 +31,7 @@ MIN = 0 -class SimuMicroscope: +class SimuMicroscope(MicroscopeBase): """Simulates a microscope connection. Has the same variables as the real JEOL/FEI equivalents, but does diff --git a/src/instamatic/camera/camera.py b/src/instamatic/camera/camera.py index 0e811700..b37c3493 100644 --- a/src/instamatic/camera/camera.py +++ b/src/instamatic/camera/camera.py @@ -24,7 +24,7 @@ def get_cam(interface: str = None): elif interface == 'gatansocket': from instamatic.camera.camera_gatan2 import CameraGatan2 as cam elif interface in ('timepix', 'pytimepix'): - from instamatic.camera import camera_timepix as cam + from instamatic.camera.camera_timepix import CameraTPX as cam elif interface in ('emmenu', 'tvips'): from instamatic.camera.camera_emmenu import CameraEMMENU as cam elif interface == 'serval': diff --git a/src/instamatic/camera/camera_base.py b/src/instamatic/camera/camera_base.py new file mode 100644 index 00000000..3693672c --- /dev/null +++ b/src/instamatic/camera/camera_base.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple + +from numpy import ndarray + +from instamatic import config + + +class CameraBase(ABC): + + # Set manually + name: str + streamable: bool + + # Set by `load_defaults` + camera_rotation_vs_stage_xy: float + default_binsize: int + default_exposure: float + dimensions: Tuple[int, int] + interface: str + possible_binsizes: List[int] + stretch_amplitude: float + stretch_azimuth: float + + @abstractmethod + def __init__(self, name: str): + self.name = name + self.load_defaults() + + @abstractmethod + def establish_connection(self): + pass + + @abstractmethod + def release_connection(self): + pass + + @abstractmethod + def get_image( + self, exposure: float = None, binsize: int = None, **kwargs + ) -> ndarray: + pass + + def get_movie( + self, n_frames: int, exposure: float = None, binsize: int = None, **kwargs + ) -> List[ndarray]: + """Basic implementation, subclasses should override with appropriate + optimization.""" + return [ + self.get_image(exposure=exposure, binsize=binsize, **kwargs) + for _ in range(n_frames) + ] + + def __enter__(self): + self.establish_connection() + return self + + def __exit__(self, kind, value, traceback): + self.release_connection() + + def get_camera_dimensions(self) -> Tuple[int, int]: + return self.dimensions + + def get_name(self) -> str: + return self.name + + def load_defaults(self): + if self.name != config.settings.camera: + config.load_camera_config(camera_name=self.name) + for key, val in config.camera.mapping.items(): + setattr(self, key, val) diff --git a/src/instamatic/camera/camera_emmenu.py b/src/instamatic/camera/camera_emmenu.py index 85284022..bad7495a 100644 --- a/src/instamatic/camera/camera_emmenu.py +++ b/src/instamatic/camera/camera_emmenu.py @@ -7,6 +7,7 @@ import numpy as np from instamatic import config +from instamatic.camera.camera_base import CameraBase logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def EMVector2dict(vec): return d -class CameraEMMENU: +class CameraEMMENU(CameraBase): """Software interface for the EMMENU program. Communicates with EMMENU over the COM interface defined by TVIPS @@ -58,6 +59,7 @@ class CameraEMMENU: name : str Name of the interface """ + streamable = False def __init__( self, @@ -65,15 +67,13 @@ def __init__( name: str = 'emmenu', ): """Initialize camera module.""" - super().__init__() + super().__init__(name) try: comtypes.CoInitializeEx(comtypes.COINIT_MULTITHREADED) except OSError: comtypes.CoInitialize() - self.name = name - self._obj = comtypes.client.CreateObject('EMMENU4.EMMENUApplication.1', comtypes.CLSCTX_ALL) self._recording = False @@ -116,22 +116,12 @@ def __init__( self._vp.DirectoryHandle = self.drc_index # set current directory - self.load_defaults() - msg = f'Camera `{self.get_camera_name()}` ({self.name}) initialized' # print(msg) logger.info(msg) atexit.register(self.release_connection) - def load_defaults(self) -> None: - if self.name != config.settings.camera: - config.load_camera_config(camera_name=self.name) - - self.__dict__.update(config.camera.mapping) - - self.streamable = False - def list_configs(self) -> list: """List the configs from the Configuration Manager.""" print(f'Configurations for camera {self.name}') @@ -142,7 +132,7 @@ def list_configs(self) -> list: for i, cfg in enumerate(self._obj.CameraConfigurations): is_selected = (current == cfg.Name) end = ' (selected)' if is_selected else '' - print(f'{i+1:2d} - {cfg.Name}{end}') + print(f'{i + 1:2d} - {cfg.Name}{end}') lst.append(cfg.Name) return lst @@ -400,7 +390,7 @@ def write_tiffs(self, start_index: int, stop_index: int, path: str, clear_buffer # self._immgr.DeleteImageBuffer(drc_index, image_index) # does not work on 3200 self._emi.DeleteImage(p) # also clears from buffer - print(f'Wrote {i+1} images to {path}') + print(f'Wrote {i + 1} images to {path}') def get_image(self, **kwargs) -> 'np.array': """Acquire image through EMMENU and return data as np array.""" @@ -506,6 +496,10 @@ def release_connection(self) -> None: comtypes.CoUninitialize() + def establish_connection(self): + # All is handled in the constructor + pass + if __name__ == '__main__': cam = CameraEMMENU() diff --git a/src/instamatic/camera/camera_gatan.py b/src/instamatic/camera/camera_gatan.py index 3b8ddddb..9263aff6 100644 --- a/src/instamatic/camera/camera_gatan.py +++ b/src/instamatic/camera/camera_gatan.py @@ -20,6 +20,7 @@ import numpy as np from instamatic import config +from instamatic.camera.camera_base import CameraBase logger = logging.getLogger(__name__) @@ -71,8 +72,9 @@ SYMBOLS['simu'] = SYMBOLS['actual'] -class CameraDLL: +class CameraDLL(CameraBase): """Interface with the CCDCOM DLLs to connect to the gatan software.""" + streamable = False def __init__(self, name: str = 'gatan'): """Initialize camera module. @@ -81,7 +83,7 @@ def __init__(self, name: str = 'gatan'): 'gatan' 'simulateDLL' """ - super().__init__() + super().__init__(name) cameradir = Path(__file__).parent @@ -94,8 +96,6 @@ def __init__(self, name: str = 'gatan'): else: raise ValueError(f'No such camera: {name}') - self.name = name - try: lib = ctypes.cdll.LoadLibrary(str(libpath)) except OSError as e: @@ -130,8 +130,6 @@ def __init__(self, name: str = 'gatan'): self.establish_connection() - self.load_defaults() - msg = f'Camera {self.get_name()} initialized' logger.info(msg) @@ -141,14 +139,6 @@ def __init__(self, name: str = 'gatan'): atexit.register(self.release_connection) - def load_defaults(self): - if self.name != config.settings.camera: - config.load_camera_config(camera_name=self.name) - - self.__dict__.update(config.camera.mapping) - - self.streamable = False - def get_image(self, exposure=None, binsize=None, **kwargs) -> np.ndarray: """Image acquisition routine. diff --git a/src/instamatic/camera/camera_gatan2.py b/src/instamatic/camera/camera_gatan2.py index cdefc84d..f6f86a15 100644 --- a/src/instamatic/camera/camera_gatan2.py +++ b/src/instamatic/camera/camera_gatan2.py @@ -7,19 +7,19 @@ import numpy as np from instamatic import config +from instamatic.camera.camera_base import CameraBase from instamatic.camera.gatansocket3 import GatanSocket logger = logging.getLogger(__name__) -class CameraGatan2: +class CameraGatan2(CameraBase): """Connect to Digital Microsgraph using the SerialEM Plugin.""" + streamable = False def __init__(self, name: str = 'gatan2'): """Initialize camera module.""" - super().__init__() - - self.name = name + super().__init__(name) self.g = GatanSocket() @@ -33,14 +33,6 @@ def __init__(self, name: str = 'gatan2'): atexit.register(self.release_connection) - def load_defaults(self) -> None: - if self.name != config.settings.camera: - config.load_camera_config(camera_name=self.name) - - self.__dict__.update(config.camera.mapping) - - self.streamable = False - def get_camera_type(self) -> str: """Get the name of the camera currently in use.""" raise NotImplementedError @@ -49,10 +41,6 @@ def get_dm_version(self) -> str: """Get the version number of DM.""" return self.g.GetDMVersion() - def get_camera_dimensions(self) -> (int, int): - """Get the maximum dimensions reported by the camera.""" - raise NotImplementedError - def get_image_dimensions(self) -> (int, int): """Get the dimensions of the image.""" binning = self.get_binning() @@ -84,7 +72,7 @@ def write_tiffs(self) -> None: path = Path(path) i = 0 - print(f'Wrote {i+1} images to {path}') + print(f'Wrote {i + 1} images to {path}') def get_image(self, exposure=0.400, @@ -157,8 +145,14 @@ def get_exposure(self) -> int: """Return exposure time in ms.""" raise NotImplementedError + def establish_connection(self): + # Already done by the constructor of GatanSocket + # self.g.connect() + pass + def release_connection(self) -> None: """Release the connection to the camera.""" + self.g.disconnect() msg = f'Connection to camera `{self.get_camera_name()}` ({self.name}) released' # print(msg) logger.info(msg) diff --git a/src/instamatic/camera/camera_merlin.py b/src/instamatic/camera/camera_merlin.py index 89a56b81..969a99bc 100644 --- a/src/instamatic/camera/camera_merlin.py +++ b/src/instamatic/camera/camera_merlin.py @@ -27,11 +27,12 @@ import logging import socket import time -from typing import Any +from typing import Any, List import numpy as np from instamatic import config +from instamatic.camera.camera_base import CameraBase try: from .merlin_io import load_mib @@ -64,28 +65,26 @@ def MPX_CMD(type_cmd: str = 'GET', cmd: str = 'DETECTORSTATUS') -> bytes: """ length = len(cmd) # tmp = 'MPX,00000000' + str(length+5) + ',' + type_cmd + ',' + cmd - tmp = f'MPX,00000000{length+5},{type_cmd},{cmd}' + tmp = f'MPX,00000000{length + 5},{type_cmd},{cmd}' logger.debug(tmp) return tmp.encode() -class CameraMerlin: +class CameraMerlin(CameraBase): """Camera interface for the Quantum Detectors Merlin camera.""" START_SIZE = 14 MAX_NUMFRAMESTOACQUIRE = 42_949_672_950 + streamable = True def __init__(self, name='merlin'): """Initialize camera module.""" - super().__init__() + super().__init__(name) - self.name = name self._state = {} self._soft_trigger_mode = False self._soft_trigger_exposure = None - self.load_defaults() - self.establish_connection() self.establish_data_connection() @@ -94,14 +93,6 @@ def __init__(self, name='merlin'): atexit.register(self.release_connection) - def load_defaults(self): - if self.name != config.settings.camera: - config.load_camera_config(camera_name=self.name) - - self.streamable = True - - self.__dict__.update(config.camera.mapping) - def receive_data(self, *, nbytes: int) -> bytearray: """Safely receive from the socket until `n_bytes` of data are received.""" @@ -348,14 +339,6 @@ def get_image_dimensions(self) -> (int, int): return dim_x, dim_y - def get_camera_dimensions(self) -> (int, int): - """Get the dimensions reported by the camera.""" - return self.dimensions - - def get_name(self) -> str: - """Get the name reported by the camera.""" - return self.name - def establish_connection(self) -> None: """Establish connection to command port of the merlin software.""" # Create command socket @@ -430,7 +413,7 @@ def test_movie(cam): overhead = avg_frametime - exposure print(f'\nExposure: {exposure}, frames: {n_frames}') - print(f'\nTotal time: {t1-t0:.3f} s - acq. time: {avg_frametime:.3f} s - overhead: {overhead:.3f}') + print(f'\nTotal time: {t1 - t0:.3f} s - acq. time: {avg_frametime:.3f} s - overhead: {overhead:.3f}') for frame in frames: assert frame.shape == (512, 512) @@ -454,7 +437,7 @@ def test_single_frame(cam): overhead = avg_frametime - exposure print(f'\nExposure: {exposure}, frames: {n_frames}') - print(f'Total time: {t1-t0:.3f} s - acq. time: {avg_frametime:.3f} s - overhead: {overhead:.3f}') + print(f'Total time: {t1 - t0:.3f} s - acq. time: {avg_frametime:.3f} s - overhead: {overhead:.3f}') def test_plot_single_image(cam): diff --git a/src/instamatic/camera/camera_serval.py b/src/instamatic/camera/camera_serval.py index 75fb9d5b..f915cdba 100644 --- a/src/instamatic/camera/camera_serval.py +++ b/src/instamatic/camera/camera_serval.py @@ -6,6 +6,7 @@ from serval_toolkit.camera import Camera as ServalCamera from instamatic import config +from instamatic.camera.camera_base import CameraBase logger = logging.getLogger(__name__) @@ -15,16 +16,13 @@ # 3. launch `instamatic` -class CameraServal: +class CameraServal(CameraBase): """Interfaces with Serval from ASI.""" + streamable = True def __init__(self, name='serval'): """Initialize camera module.""" - super().__init__() - - self.name = name - - self.load_defaults() + super().__init__(name) self.establish_connection() @@ -33,14 +31,6 @@ def __init__(self, name='serval'): atexit.register(self.release_connection) - def load_defaults(self): - if self.name != config.settings.camera: - config.load_camera_config(camera_name=self.name) - - self.streamable = True - - self.__dict__.update(config.camera.mapping) - def get_image(self, exposure=None, binsize=None, **kwargs) -> np.ndarray: """Image acquisition routine. If the exposure and binsize are not given, the default values are read from the config file. @@ -108,14 +98,6 @@ def get_image_dimensions(self) -> (int, int): return dim_x, dim_y - def get_camera_dimensions(self) -> (int, int): - """Get the dimensions reported by the camera.""" - return self.dimensions - - def get_name(self) -> str: - """Get the name reported by the camera.""" - return self.name - def establish_connection(self) -> None: """Establish connection to the camera.""" self.conn = ServalCamera() diff --git a/src/instamatic/camera/camera_simu.py b/src/instamatic/camera/camera_simu.py index eada7bc9..da27679f 100644 --- a/src/instamatic/camera/camera_simu.py +++ b/src/instamatic/camera/camera_simu.py @@ -5,24 +5,22 @@ import numpy as np from instamatic import config +from instamatic.camera.camera_base import CameraBase logger = logging.getLogger(__name__) -class CameraSimu: +class CameraSimu(CameraBase): """Simple class that simulates the camera interface and mocks the method calls.""" + streamable = True def __init__(self, name='simulate'): """Initialize camera module.""" - super().__init__() - - self.name = name + super().__init__(name) self.establish_connection() - self.load_defaults() - msg = f'Camera {self.get_name()} initialized' logger.info(msg) @@ -34,14 +32,6 @@ def __init__(self, name='simulate'): self._autoincrement = True self._start_record_time = -1 - def load_defaults(self): - if self.name != config.settings.camera: - config.load_camera_config(camera_name=self.name) - - self.streamable = True - - self.__dict__.update(config.camera.mapping) - def get_image(self, exposure=None, binsize=None, **kwargs) -> np.ndarray: """Image acquisition routine. If the exposure and binsize are not given, the default values are read from the config file. @@ -110,14 +100,6 @@ def get_image_dimensions(self) -> (int, int): return dim_x, dim_y - def get_camera_dimensions(self) -> (int, int): - """Get the dimensions reported by the camera.""" - return self.dimensions - - def get_name(self) -> str: - """Get the name reported by the camera.""" - return self.name - def establish_connection(self) -> None: """Establish connection to the camera.""" res = 1 diff --git a/src/instamatic/camera/camera_timepix.py b/src/instamatic/camera/camera_timepix.py index 9a4178e4..37418499 100644 --- a/src/instamatic/camera/camera_timepix.py +++ b/src/instamatic/camera/camera_timepix.py @@ -10,6 +10,7 @@ import numpy as np from instamatic import config +from instamatic.camera.camera_base import CameraBase from instamatic.utils import high_precision_timers high_precision_timers.enable() @@ -50,8 +51,11 @@ def correct_cross(raw, factor=2.15): raw[:, 258:261] = raw[:, 260:261] / factor -class CameraTPX: +class CameraTPX(CameraBase): + streamable = True + def __init__(self, name='pytimepix'): + super().__init__(name) libdrc = Path(__file__).parent self.lockfile = libdrc / 'timepix.lockfile' @@ -70,12 +74,9 @@ def __init__(self, name='pytimepix'): self.lib.EMCameraObj_timerExpired.restype = c_bool self.obj = self.lib.EMCameraObj_new() - atexit.register(self.disconnect) + atexit.register(self.release_connection) self.is_connected = None - self.name = self.get_name() - self.load_defaults() - def acquire_lock(self): try: os.rename(self.lockfile, self.lockfile) @@ -97,7 +98,7 @@ def uninit(self): """Doesn't do anything.""" self.lib.EMCameraObj_UnInit(self.obj) - def connect(self, hwId): + def establish_connection(self, hwId): hwId = c_int(hwId) ret = self.lib.EMCameraObj_Connect(self.obj, hwId) if ret: @@ -107,7 +108,10 @@ def connect(self, hwId): raise OSError('Could not establish connection to camera.') return ret - def disconnect(self): + def connect(self): + self.establish_connection() + + def release_connection(self): if not self.is_connected: return True ret = self.lib.EMCameraObj_Disconnect(self.obj) @@ -118,6 +122,9 @@ def disconnect(self): print('Camera disconnect failed...') return ret + def disconnect(self): + self.release_connection() + def get_frame_size(self): return self.lib.EMCameraObj_getFrameSize(self.obj) @@ -134,7 +141,7 @@ def read_real_dacs(self, filename): self.lib.EMCameraObj_readRealDacs(self.obj, buffer) except BaseException: traceback.print_exc() - self.disconnect() + self.release_connection() sys.exit() def read_hw_dacs(self, filename): @@ -150,7 +157,7 @@ def read_hw_dacs(self, filename): self.lib.EMCameraObj_readHwDacs(self.obj, buffer) except BaseException: traceback.print_exc() - self.disconnect() + self.release_connection() sys.exit() def read_pixels_cfg(self, filename): @@ -167,7 +174,7 @@ def read_pixels_cfg(self, filename): self.lib.EMCameraObj_readPixelsCfg(self.obj, buffer) except BaseException: traceback.print_exc() - self.disconnect() + self.release_connection() sys.exit() def process_real_dac(self, chipnr=None, index=None, key=None, value=None): @@ -293,19 +300,8 @@ def get_image(self, exposure): def get_name(self): return 'timepix' - def get_camera_dimensions(self) -> (int, int): - return self.dimensions - - def load_defaults(self): - if self.name != config.settings.camera: - config.load_camera_config(camera_name=self.name) - - self.__dict__.update(config.camera.mapping) - - self.streamable = True - -def initialize(config, name='pytimepix'): +def initialize(config, name='pytimepix') -> CameraTPX: from pathlib import Path base = Path(config).parent @@ -324,7 +320,7 @@ def initialize(config, name='pytimepix'): pixelsCfg = base / inp[1] cam = CameraTPX(name=name) - cam.connect(hwId) + cam.establish_connection(hwId) cam.init() @@ -365,8 +361,8 @@ def initialize(config, name='pytimepix'): for x in range(n): cam.acquire_data(t) dt = time.perf_counter() - t0 - print(f'Total time: {dt:.1f} s, acquisition time: {1000*(dt/n):.2f} ms, overhead: {1000*(dt/n - t):.2f} ms') + print(f'Total time: {dt:.1f} s, acquisition time: {1000 * (dt / n):.2f} ms, overhead: {1000 * (dt / n - t):.2f} ms') embed(banner1='') - isDisconnected = cam.disconnect() + isDisconnected = cam.release_connection() diff --git a/src/instamatic/camera/videostream.py b/src/instamatic/camera/videostream.py index 4ed978a1..f4e116ed 100644 --- a/src/instamatic/camera/videostream.py +++ b/src/instamatic/camera/videostream.py @@ -1,6 +1,8 @@ import atexit import threading +from instamatic.camera.camera_base import CameraBase + from .camera import Camera @@ -15,7 +17,7 @@ class ImageGrabber: routine. """ - def __init__(self, cam, callback, frametime: float = 0.05): + def __init__(self, cam: CameraBase, callback, frametime: float = 0.05): super().__init__() self.callback = callback @@ -117,7 +119,7 @@ def send_frame(self, frame, acquire=False): self.frame = frame self.grabber.lock.release() - def setup_grabber(self): + def setup_grabber(self) -> ImageGrabber: grabber = ImageGrabber(self.cam, callback=self.send_frame, frametime=self.frametime) atexit.register(grabber.stop) return grabber diff --git a/src/instamatic/experiments/autocred/experiment.py b/src/instamatic/experiments/autocred/experiment.py index 7e84c0c2..62c7274c 100644 --- a/src/instamatic/experiments/autocred/experiment.py +++ b/src/instamatic/experiments/autocred/experiment.py @@ -25,6 +25,7 @@ ) from instamatic.calibrate.center_z import center_z_height_HYMethod from instamatic.calibrate.filenames import * +from instamatic.experiments.experiment_base import ExperimentBase from instamatic.formats import write_tiff from instamatic.neural_network import predict, preprocess from instamatic.processing.find_crystals import find_crystals_timepix @@ -121,7 +122,7 @@ def load_IS_Calibrations(imageshift, ctrl, diff_defocus, logger, mode): return transform_imgshift, c -class Experiment: +class Experiment(ExperimentBase): def __init__(self, ctrl, exposure_time, exposure_time_image, diff --git a/src/instamatic/experiments/cred/experiment.py b/src/instamatic/experiments/cred/experiment.py index 134aed6b..e134907f 100644 --- a/src/instamatic/experiments/cred/experiment.py +++ b/src/instamatic/experiments/cred/experiment.py @@ -8,6 +8,7 @@ import instamatic from instamatic import config +from instamatic.experiments.experiment_base import ExperimentBase from instamatic.formats import write_tiff from instamatic.processing.ImgConversionTPX import ImgConversionTPX as ImgConversion @@ -23,7 +24,7 @@ def print_and_log(msg, logger=None): logger.info(msg) -class Experiment: +class Experiment(ExperimentBase): """Initialize continuous rotation electron diffraction experiment. ctrl: @@ -149,7 +150,7 @@ def fmt(arr): print(f'Time Period End: {self.t_end}', file=f) print(f'Starting angle: {self.start_angle:.2f} degrees', file=f) print(f'Ending angle: {self.end_angle:.2f} degrees', file=f) - print(f'Rotation range: {self.end_angle-self.start_angle:.2f} degrees', file=f) + print(f'Rotation range: {self.end_angle - self.start_angle:.2f} degrees', file=f) print(f'Exposure Time: {self.exposure:.3f} s', file=f) print(f'Acquisition time: {self.acquisition_time:.3f} s', file=f) print(f'Total time: {self.total_time:.3f} s', file=f) diff --git a/src/instamatic/experiments/cred_gatan/experiment.py b/src/instamatic/experiments/cred_gatan/experiment.py index b9b2dc8d..20ca784b 100644 --- a/src/instamatic/experiments/cred_gatan/experiment.py +++ b/src/instamatic/experiments/cred_gatan/experiment.py @@ -9,10 +9,11 @@ import instamatic from instamatic import config +from instamatic.experiments.experiment_base import ExperimentBase from instamatic.formats import write_tiff -class Experiment: +class Experiment(ExperimentBase): """Class to control data collection through DM to collect continuous rotation electron diffraction data. @@ -97,7 +98,7 @@ def prepare_tracking(self): y_offset = int(self.track_func(start_angle)) x_offset = self.x_offset - print(f'(autotracking) setting a={start_angle:.0f}, x={self.start_x+x_offset:.0f}, y={self.start_y+y_offset:.0f}, z={self.start_z:.0f}') + print(f'(autotracking) setting a={start_angle:.0f}, x={self.start_x + x_offset:.0f}, y={self.start_y + y_offset:.0f}, z={self.start_z:.0f}') self.ctrl.stage.set_xy_with_backlash_correction(x=self.start_x + x_offset, y=self.start_y + y_offset, step=10000) self.ctrl.stage.set(a=start_angle, z=self.start_z) @@ -343,7 +344,7 @@ def log_end_status(self): print(f'Number of frames: {self.nframes}', file=f) print(f'Starting angle: {self.start_angle:.2f} degrees', file=f) print(f'Ending angle: {self.end_angle:.2f} degrees', file=f) - print(f'Rotation range: {self.end_angle-self.start_angle:.2f} degrees', file=f) + print(f'Rotation range: {self.end_angle - self.start_angle:.2f} degrees', file=f) print(f'Rotation speed: {self.rotation_speed:.3f} degrees/s', file=f) # print(f"Exposure Time: {self.timings.exposure_time:.3f} s", file=f) # print(f"Acquisition time: {self.timings.acquisition_time:.3f} s", file=f) diff --git a/src/instamatic/experiments/cred_tvips/experiment.py b/src/instamatic/experiments/cred_tvips/experiment.py index 7fd95a8c..1412b714 100644 --- a/src/instamatic/experiments/cred_tvips/experiment.py +++ b/src/instamatic/experiments/cred_tvips/experiment.py @@ -10,6 +10,7 @@ import instamatic from instamatic import config +from instamatic.experiments.experiment_base import ExperimentBase from instamatic.formats import write_tiff from instamatic.tools import get_acquisition_time @@ -139,11 +140,11 @@ def run_from_tracking_file(self): t1 = time.perf_counter() dt = t1 - t0 print(f'Serial experiment finished -> {n_measured} crystals measured') - print(f'Time taken: {dt:.1f} s, {dt/n_measured:.1f} s/crystal') + print(f'Time taken: {dt:.1f} s, {dt / n_measured:.1f} s/crystal') print(f'Data directory: {self.path}') -class Experiment: +class Experiment(ExperimentBase): """Class to control data collection through EMMenu to collect continuous rotation electron diffraction data. @@ -228,7 +229,7 @@ def prepare_tracking(self): y_offset = int(self.track_func(start_angle)) x_offset = self.x_offset - print(f'(autotracking) setting a={start_angle:.0f}, x={self.start_x+x_offset:.0f}, y={self.start_y+y_offset:.0f}, z={self.start_z:.0f}') + print(f'(autotracking) setting a={start_angle:.0f}, x={self.start_x + x_offset:.0f}, y={self.start_y + y_offset:.0f}, z={self.start_z:.0f}') self.ctrl.stage.set_xy_with_backlash_correction(x=self.start_x + x_offset, y=self.start_y + y_offset, step=10000) self.ctrl.stage.set(a=start_angle, z=self.start_z) @@ -269,6 +270,9 @@ def get_ready(self): print('Ready...') + def setup(self): + self.get_ready() + def manual_activation(self) -> float: ACTIVATION_THRESHOLD = 0.2 @@ -490,7 +494,7 @@ def log_end_status(self): print(f'Number of frames: {self.nframes}', file=f) print(f'Starting angle: {self.start_angle:.2f} degrees', file=f) print(f'Ending angle: {self.end_angle:.2f} degrees', file=f) - print(f'Rotation range: {self.end_angle-self.start_angle:.2f} degrees', file=f) + print(f'Rotation range: {self.end_angle - self.start_angle:.2f} degrees', file=f) print(f'Rotation speed: {self.rotation_speed:.3f} degrees/s', file=f) print(f'Exposure Time: {self.timings.exposure_time:.3f} s', file=f) print(f'Acquisition time: {self.timings.acquisition_time:.3f} s', file=f) diff --git a/src/instamatic/experiments/experiment_base.py b/src/instamatic/experiments/experiment_base.py new file mode 100644 index 00000000..be43b467 --- /dev/null +++ b/src/instamatic/experiments/experiment_base.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod + + +class ExperimentBase(ABC): + """Experiment base class.""" + + @abstractmethod + def start_collection(self): + pass + + def setup(self): + pass + + def teardown(self): + pass + + def __enter__(self): + self.setup() + return self + + def __exit__(self, kind, value, traceback): + self.teardown() diff --git a/src/instamatic/experiments/red/experiment.py b/src/instamatic/experiments/red/experiment.py index 028e78e5..ca3ed72f 100644 --- a/src/instamatic/experiments/red/experiment.py +++ b/src/instamatic/experiments/red/experiment.py @@ -7,11 +7,12 @@ from tqdm.auto import tqdm from instamatic import config +from instamatic.experiments.experiment_base import ExperimentBase from instamatic.formats import write_tiff from instamatic.processing.ImgConversionTPX import ImgConversionTPX as ImgConversion -class Experiment: +class Experiment(ExperimentBase): """Initialize stepwise rotation electron diffraction experiment. ctrl: @@ -141,7 +142,7 @@ def finalize(self): self.stretch_amplitude = config.camera.stretch_amplitude with open(self.path / 'summary.txt', 'a') as f: - print(f'Rotation range: {self.end_angle-self.start_angle:.2f} degrees', file=f) + print(f'Rotation range: {self.end_angle - self.start_angle:.2f} degrees', file=f) print(f'Exposure Time: {self.exposure_time:.3f} s', file=f) print(f'Spot Size: {self.spotsize}', file=f) print(f'Camera length: {self.camera_length} mm', file=f) @@ -184,6 +185,9 @@ def finalize(self): return True + def teardown(self): + self.finalize() + def main(): from instamatic import TEMController diff --git a/src/instamatic/experiments/serialed/experiment.py b/src/instamatic/experiments/serialed/experiment.py index 6ccd3a63..8155de31 100644 --- a/src/instamatic/experiments/serialed/experiment.py +++ b/src/instamatic/experiments/serialed/experiment.py @@ -9,6 +9,7 @@ from instamatic import config from instamatic.calibrate import CalibBeamShift, CalibDirectBeam +from instamatic.experiments.experiment_base import ExperimentBase from instamatic.formats import * from instamatic.processing.find_crystals import find_crystals, find_crystals_timepix from instamatic.processing.flatfield import apply_flatfield_correction, remove_deadpixels @@ -131,7 +132,7 @@ def get_offsets_in_scan_area(box_x, box_y=0, radius=75, padding=2, k=1.0, angle= return np.vstack((x_offsets, y_offsets)).T -class Experiment: +class Experiment(ExperimentBase): """Data collection protocol for serial electron diffraction. Related publication: J. Appl. Cryst. (2018). 51, 1262-1273 @@ -541,6 +542,9 @@ def run(self, ctrl=None, **kwargs): print('\n\nData collection finished.') + def start_collection(self, ctrl=None, **kwargs): + self.run(ctrl=ctrl, **kwargs) + def main(): import argparse diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 4ddaaed4..c1676e33 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,14 @@ import os from pathlib import Path +from typing import TYPE_CHECKING import pytest +# Actually importing the controller class breaks the tests +if TYPE_CHECKING: + from instamatic.TEMController.TEMController import TEMController + + base_drc = Path(__file__).parent os.environ['instamatic'] = str(base_drc.absolute()) @@ -12,7 +18,7 @@ def pytest_configure(): @pytest.fixture(scope='module') -def ctrl(): +def ctrl() -> 'TEMController': from instamatic.TEMController import initialize ctrl = initialize() diff --git a/tests/mock/__init__.py b/tests/mock/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mock/camera.py b/tests/mock/camera.py new file mode 100644 index 00000000..5926921f --- /dev/null +++ b/tests/mock/camera.py @@ -0,0 +1,91 @@ +from instamatic.camera.camera_base import CameraBase +from instamatic.camera.camera_emmenu import CameraEMMENU +from instamatic.camera.camera_gatan import CameraDLL +from instamatic.camera.camera_gatan2 import CameraGatan2 +from instamatic.camera.camera_merlin import CameraMerlin +from instamatic.camera.camera_simu import CameraSimu +from instamatic.camera.camera_timepix import CameraTPX + +try: + from instamatic.camera.camera_serval import CameraServal +except ImportError: + + class CameraServal: + pass + + +from instamatic import config + +from .socket import GatanSocketMock, SockMock + +__all__ = [ + 'CameraDLLMock', + 'CameraEMMENUMock', + 'CameraGatan2Mock', + 'CameraMerlinMock', + 'CameraServalMock', + 'CameraSimuMock', + 'CameraTPXMock', + 'CameraMock', +] + + +class CameraMockBase: + """Override `load_defaults` as the config file for each microscope is not + present.""" + + def load_defaults(self): + for key, val in config.camera.mapping.items(): + setattr(self, key, val) + + +class CameraGatan2Mock(CameraMockBase, CameraGatan2): + def __init__(self, name: str = 'gatan2'): + self.name = name + self.g = GatanSocketMock(port='') + self._recording = False + + +class CameraMerlinMock(CameraMockBase, CameraMerlin): + host = '127.0.0.1' + commandport = 0 + dataport = 1 + + def establish_connection(self) -> None: + self.s_cmd = SockMock() + + def establish_data_connection(self) -> None: + self.s_data = SockMock() + + +class CameraSimuMock(CameraMockBase, CameraSimu): + pass + + +class CameraDLLMock(CameraMockBase, CameraDLL): + def establish_connection(self) -> None: + # The connection opens a window. + # Currently no simple way to close it, as it halts execution. + # return super().establish_connection() + raise NotImplementedError() + + +class CameraTPXMock(CameraMockBase, CameraTPX): + def acquire_lock(self): + from pathlib import Path + self.lockfile = Path(__file__).with_name('timepix_mock.lockfile') + return super().acquire_lock() + + +class CameraEMMENUMock(CameraMockBase, CameraEMMENU): + def __init__(self, *args, **kwargs): + raise NotImplementedError() + + +class CameraServalMock(CameraMockBase, CameraServal): + def __init__(self, *args, **kwargs): + raise NotImplementedError() + + +class CameraMock(CameraMockBase, CameraBase): + pass # Will raise a error as not all abstract methods are implemented diff --git a/tests/mock/socket.py b/tests/mock/socket.py new file mode 100644 index 00000000..20e26eb0 --- /dev/null +++ b/tests/mock/socket.py @@ -0,0 +1,34 @@ +from instamatic.camera.gatansocket3 import GatanSocket + + +class SockMock: + """Class to mock a socket connection.""" + + def __init__(self): + self.sent = [] + + def reset(self): + self.sent.clear() + + def shutdown(self, how: int) -> None: + self.reset() + + def close(self) -> None: + self.reset() + + def connect(self, address) -> None: + self.reset() + + def disconnect(self) -> None: + self.reset() + + def sendall(self, data) -> None: + self.sent.append(data) + + def recv(self, bufsize: int) -> bytes: + return bytes([0] * bufsize) + + +class GatanSocketMock(GatanSocket): + def connect(self): + self.sock = SockMock() diff --git a/tests/mock/timepix_mock.lockfile b/tests/mock/timepix_mock.lockfile new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_camera.py b/tests/test_camera.py index e3fde796..ed717152 100644 --- a/tests/test_camera.py +++ b/tests/test_camera.py @@ -1,3 +1,24 @@ +import pytest + +from instamatic.camera.camera_base import CameraBase +from instamatic.camera.camera_emmenu import CameraEMMENU +from instamatic.camera.camera_gatan import CameraDLL +from instamatic.camera.camera_gatan2 import CameraGatan2 +from instamatic.camera.camera_merlin import CameraMerlin +from instamatic.camera.camera_simu import CameraSimu +from instamatic.camera.camera_timepix import CameraTPX + +from .mock.camera import ( + CameraDLLMock, + CameraEMMENUMock, + CameraGatan2Mock, + CameraMerlinMock, + CameraServalMock, + CameraSimuMock, + CameraTPXMock, +) + + def test_get_image(ctrl): bin1 = 1 bin2 = 2 @@ -22,3 +43,36 @@ def test_functions(ctrl): dims = ctrl.cam.get_image_dimensions() assert isinstance(dims, tuple) assert len(dims) == 2 + + +@pytest.mark.parametrize( + 'cam', + [ + pytest.param(CameraDLLMock, marks=pytest.mark.xfail(reason='establish_connection opens a popup window which halts execution')), + pytest.param(CameraEMMENUMock, marks=pytest.mark.xfail(reason='Not implemented')), + CameraGatan2Mock, + CameraMerlinMock, + pytest.param(CameraServalMock, marks=pytest.mark.xfail(reason='Not implemented')), + CameraSimuMock, + CameraTPXMock, + ], +) +def test_init_mock(cam): + c = cam() + + +@pytest.mark.parametrize( + 'cam', + [ + CameraSimu, + pytest.param(CameraDLL, marks=pytest.mark.xfail(reason='Needs config')), + pytest.param(CameraGatan2, marks=pytest.mark.xfail(reason='Needs config + server')), + CameraTPX, + pytest.param(CameraEMMENU, marks=pytest.mark.xfail(reason='WinError: Invalid class string')), + pytest.param(CameraMerlin, marks=pytest.mark.xfail(reason='Needs config + server')), + ] +) +def test_init(cam): + # Use "test" as the name of the camera, as this is where the settings are read from + c = cam(name='test') + assert isinstance(c, CameraBase) diff --git a/tests/test_experiments.py b/tests/test_experiments.py index b8d0bff1..6f97cefd 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -1,90 +1,86 @@ -import tempfile import threading +from pathlib import Path from unittest.mock import MagicMock - -def test_cred(ctrl): - """This one is difficult to test with threads and events.""" - from instamatic.experiments import cred - - stopEvent = threading.Event() - stopEvent.set() - - tempdrc = tempfile.TemporaryDirectory() - expdir = tempdrc.name - - logger = MagicMock() - - cexp = cred.experiment.Experiment( - ctrl, - path=expdir, - stop_event=stopEvent, - log=logger, - mode='simulate', - ) - cexp.start_collection() - - tempdrc.cleanup() - - -def test_cred_tvips(ctrl): - from instamatic.experiments import cRED_tvips - - tempdrc = tempfile.TemporaryDirectory() - expdir = tempdrc.name - - logger = MagicMock() - - ctrl.stage.a = 20 - target_angle = -20 - exposure = 0.1 - manual_control = False - mode = 'diff' - - exp = cRED_tvips.Experiment( - ctrl=ctrl, - path=expdir, - log=logger, - mode=mode, - track=None, - exposure=exposure, - ) - exp.get_ready() - - exp.start_collection( - target_angle=target_angle, - manual_control=manual_control, - ) - - tempdrc.cleanup() - - -def test_red(ctrl): - from instamatic.experiments import RED - - tempdrc = tempfile.TemporaryDirectory() - expdir = tempdrc.name +import pytest + +from instamatic.experiments import RED, cRED, cRED_tvips +from instamatic.experiments.experiment_base import ExperimentBase + + +def test_autoCRED(ctrl): + from instamatic.experiments import autocRED + assert issubclass(autocRED.Experiment, ExperimentBase) + exp = autocRED.Experiment(ctrl, *[None] * 17, path=Path()) + pytest.xfail('Too complex to test at this point') + + +def test_serialED(ctrl): + from instamatic.experiments import serialED + assert issubclass(serialED.Experiment, ExperimentBase) + with pytest.raises(OSError): + exp = serialED.Experiment(ctrl, {}) + pytest.xfail('TODO') + + +@pytest.mark.parametrize( + ['exp_cls', 'init_kwargs', 'collect_kwargs', 'num_collections'], + [ + ( + cRED.Experiment, + { + 'stop_event': threading.Event(), + 'mode': 'simulate', + }, + {}, + 1, + ), + ( + cRED_tvips.Experiment, + { + 'mode': 'diff', + 'track': None, + 'exposure': 0.1, + }, + { + 'target_angle': -20, + 'manual_control': False, + }, + 1, + ), + ( + RED.Experiment, + { + 'flatfield': None, + }, + { + 'exposure_time': 0.01, + 'tilt_range': 5, + 'stepsize': 1.0, + }, + 2, + ), + ], +) +def test_experiment( + exp_cls: 'type[ExperimentBase]', + init_kwargs: dict, + collect_kwargs: dict, + num_collections: int, + ctrl, + tmp_path +): + init_kwargs['ctrl'] = ctrl + + init_kwargs['path'] = tmp_path logger = MagicMock() + init_kwargs['log'] = logger - exposure_time = 0.01 - tilt_range = 5 - stepsize = 1.0 - - red_exp = RED.Experiment( - ctrl=ctrl, - path=expdir, - log=logger, - flatfield=None, - ) - - for x in range(2): - red_exp.start_collection( - exposure_time=exposure_time, - tilt_range=tilt_range, - stepsize=stepsize, - ) - - red_exp.finalize() + stop_event = init_kwargs.get('stop_event') + if stop_event is not None: + stop_event.set() - tempdrc.cleanup() + with exp_cls(**init_kwargs) as exp: + for _ in range(num_collections): + exp.start_collection(**collect_kwargs) diff --git a/tests/test_formats.py b/tests/test_formats.py index 756e0c54..c12df53c 100644 --- a/tests/test_formats.py +++ b/tests/test_formats.py @@ -1,4 +1,5 @@ import os +from contextlib import nullcontext as does_not_raise import numpy as np import pytest @@ -18,67 +19,74 @@ def header(): return {'value': 123, 'string': 'test'} -def test_tiff(data, header): - out = 'out.tiff' +@pytest.fixture(scope='module') +def temp_data_file(tmp_path_factory) -> str: + return str(tmp_path_factory.mktemp('data', numbered=True) / 'out.') - formats.write_tiff(out, data, header) - assert os.path.exists(out) - - img, h = formats.read_image(out) - - assert np.allclose(img, data) - assert header == h - - -def test_cbf(data, header): - out = 'out.cbf' - - formats.write_cbf(out, data, header) - - assert os.path.exists(out) - - # Reader Not implemented: - with pytest.raises(NotImplementedError): - img, h = formats.read_image(out) - - -def test_mrc(data, header): - out = 'out.mrc' +@pytest.mark.parametrize( + ['format', 'write_func', 'with_header'], + [ + ('tiff', formats.write_tiff, True), + ('cbf', formats.write_cbf, True), + ('mrc', formats.write_mrc, False), + ('smv', formats.write_adsc, True), + ('img', formats.write_adsc, True), + ('h5', formats.write_hdf5, True), + ], +) +def test_write(format, write_func, with_header, data, header, temp_data_file): + out = temp_data_file + format - # Header not supported - formats.write_mrc(out, data) + write_func(out, data, header if with_header else None) assert os.path.exists(out) - img, h = formats.read_image(out) - assert np.allclose(img, data) - assert isinstance(header, dict) +@pytest.mark.parametrize( + ['format', 'write_func', 'alt'], + [ + ('tif', formats.write_tiff, 'tiff'), + ('hdf5', formats.write_hdf5, 'h5'), + ], +) +def test_write_rename_ext(format, write_func, alt, data, header, temp_data_file): + out = temp_data_file + 'alt.' + format + out_alt = temp_data_file + 'alt.' + alt + + write_func(out, data, header) + + assert os.path.exists(out_alt) + + +@pytest.mark.parametrize( + ['format', 'write_func', 'with_header', 'raises'], + [ + + ('tiff', formats.write_tiff, True, does_not_raise()), + ('smv', formats.write_adsc, True, does_not_raise()), + ('img', formats.write_adsc, True, does_not_raise()), + ('h5', formats.write_hdf5, True, does_not_raise()), + # Header is not supported + ('mrc', formats.write_mrc, False, pytest.raises(ValueError, match='Header mismatch')), + ('cbf', formats.write_cbf, True, pytest.raises(NotImplementedError)), + ('invalid_extension', lambda *args: None, False, pytest.raises(OSError)), + ('does_not_exist.h5', lambda *args: None, False, pytest.raises(FileNotFoundError)), + ], +) +def test_read(format, write_func, with_header, raises, data, header, temp_data_file): + # Generate file + out = temp_data_file + format + write_func(out, data, header if with_header else None) + + with raises: + out = temp_data_file + format + img, h = formats.read_image(out) -def test_smv(data, header): - out = 'out.smv' - - formats.write_adsc(out, data, header) - - assert os.path.exists(out) - - img, h = formats.read_image(out) - - assert np.allclose(img, data) - assert 'value' in h # changes type to str - assert h['string'] == header['string'] - - -def test_hdf5(data, header): - out = 'out.h5' - - formats.write_hdf5(out, data, header) - - assert os.path.exists(out) - - img, h = formats.read_image(out) + assert np.allclose(img, data) + assert isinstance(h, dict) - assert np.allclose(img, data) - assert header == h + # Check if the header we want is in the header we read + if not all(str(v) == str(h.get(k)) for k, v in header.items()): + raise ValueError('Header mismatch')