Skip to content

Commit

Permalink
Rebunch is re-done only it hasnt been done before
Browse files Browse the repository at this point in the history
  • Loading branch information
warunawickramasingha committed Jul 15, 2024
1 parent fef4c51 commit 1ebf2f6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion diffraction/WISH/bragg-detect/cnn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Inorder to use the pretrained Faster RCNN model inside mantid, below steps are r
* Download the script repository's `scriptrepository\diffraction\WISH` directory as instructed here https://docs.mantidproject.org/nightly/workbench/scriptrepository.html
* Check whether `<local path>\diffraction\WISH` path is available at `Python Script Directories` tab from `File->Manage User Directories`.
* Close the workbench
* From command line, change the directory to the place where the scripts were downloaded ex: `<local path>\diffraction\WISH`
* From command line, change the directory to the place where the scripts were downloaded ex: `<local path>\diffraction\WISH\bragg-detect\cnn`
* Within the same conda enviroment, install pytorch dependancies by running `pip install -r requirements.txt`
* Install NVIDIA CUDA Deep Neural Network library (cuDNN) by running `conda install -c anaconda cudnn`
* Re-launch workbench from `workbench` command
Expand Down
11 changes: 9 additions & 2 deletions diffraction/WISH/bragg-detect/cnn/WISHDataSets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from bragg_utils import make_3d_array
from mantid.simpleapi import Load, Rebunch, Workspace
from mantid.simpleapi import Load, Rebunch, Workspace, AnalysisDataService


class WISHWorkspaceDataSet(tc.utils.data.Dataset):
Expand All @@ -12,11 +12,18 @@ def __init__(self, workspace):
if workspace.getAxis(0).getUnit().unitID() != "TOF":
raise RuntimeError("Unit of the X-axis is expected to be TOF")
ws = workspace
ws_name = ws.getName()
elif isinstance(workspace, str):
ws = Load(Filename=workspace, OutputWorkspace=workspace, EnableLogging=False)
ws_name = ws
else:
raise RuntimeError("Invalid workspace type - must be Workspace object or a name of a workspace to Load")
self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace="_cnn_rebunched", StoreInADS=False, EnableLogging=False)

rebunched_ws_name = f"__{ws_name}_cnn_rebunched"
if AnalysisDataService.doesExist(rebunched_ws_name):
self.rebunched_ws = AnalysisDataService[rebunched_ws_name]
else:
self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace=rebunched_ws_name, EnableLogging=False)
self.ws_3d = make_3d_array(self.rebunched_ws)
print(f"Data set for {workspace} is created with shape{self.ws_3d.shape}")
self.trans = A.Compose([A.pytorch.transforms.ToTensorV2(p=1.0)])
Expand Down

0 comments on commit 1ebf2f6

Please sign in to comment.