diff --git a/diffraction/WISH/bragg-detect/cnn/README.md b/diffraction/WISH/bragg-detect/cnn/README.md index 633af99..313e6a8 100644 --- a/diffraction/WISH/bragg-detect/cnn/README.md +++ b/diffraction/WISH/bragg-detect/cnn/README.md @@ -9,7 +9,7 @@ Inorder to use the pretrained Faster RCNN model inside mantid, below steps are r * Download the script repository's `scriptrepository\diffraction\WISH` directory as instructed here https://docs.mantidproject.org/nightly/workbench/scriptrepository.html * Check whether `\diffraction\WISH` path is available at `Python Script Directories` tab from `File->Manage User Directories`. * Close the workbench -* From command line, change the directory to the place where the scripts were downloaded ex: `\diffraction\WISH` +* From command line, change the directory to the place where the scripts were downloaded ex: `\diffraction\WISH\bragg-detect\cnn` * Within the same conda enviroment, install pytorch dependancies by running `pip install -r requirements.txt` * Install NVIDIA CUDA Deep Neural Network library (cuDNN) by running `conda install -c anaconda cudnn` * Re-launch workbench from `workbench` command diff --git a/diffraction/WISH/bragg-detect/cnn/WISHDataSets.py b/diffraction/WISH/bragg-detect/cnn/WISHDataSets.py index 1ad3048..c0fe447 100644 --- a/diffraction/WISH/bragg-detect/cnn/WISHDataSets.py +++ b/diffraction/WISH/bragg-detect/cnn/WISHDataSets.py @@ -3,7 +3,7 @@ import albumentations as A from albumentations.pytorch.transforms import ToTensorV2 from bragg_utils import make_3d_array -from mantid.simpleapi import Load, Rebunch, Workspace +from mantid.simpleapi import Load, Rebunch, Workspace, AnalysisDataService class WISHWorkspaceDataSet(tc.utils.data.Dataset): @@ -12,11 +12,18 @@ def __init__(self, workspace): if workspace.getAxis(0).getUnit().unitID() != "TOF": raise RuntimeError("Unit of the X-axis is expected to be TOF") ws = workspace + ws_name = ws.getName() elif isinstance(workspace, str): ws = Load(Filename=workspace, OutputWorkspace=workspace, EnableLogging=False) + ws_name = ws else: raise RuntimeError("Invalid workspace type - must be Workspace object or a name of a workspace to Load") - self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace="_cnn_rebunched", StoreInADS=False, EnableLogging=False) + + rebunched_ws_name = f"__{ws_name}_cnn_rebunched" + if AnalysisDataService.doesExist(rebunched_ws_name): + self.rebunched_ws = AnalysisDataService[rebunched_ws_name] + else: + self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace=rebunched_ws_name, EnableLogging=False) self.ws_3d = make_3d_array(self.rebunched_ws) print(f"Data set for {workspace} is created with shape{self.ws_3d.shape}") self.trans = A.Compose([A.pytorch.transforms.ToTensorV2(p=1.0)])