diff --git a/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py b/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py index 23ba523..61d769f 100644 --- a/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py +++ b/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py @@ -55,6 +55,7 @@ def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold #Filter duplicates by qlab BaseSX.remove_duplicate_peaks_by_qlab(peaksws, q_tol) + data_set.delete_rebunched_ws() print(f"Bragg peaks finding from FasterRCNN model is completed in {time.time()-start_time} seconds!") diff --git a/diffraction/WISH/bragg-detect/cnn/WISHDataSets.py b/diffraction/WISH/bragg-detect/cnn/WISHDataSets.py index c0fe447..11b038d 100644 --- a/diffraction/WISH/bragg-detect/cnn/WISHDataSets.py +++ b/diffraction/WISH/bragg-detect/cnn/WISHDataSets.py @@ -3,7 +3,7 @@ import albumentations as A from albumentations.pytorch.transforms import ToTensorV2 from bragg_utils import make_3d_array -from mantid.simpleapi import Load, Rebunch, Workspace, AnalysisDataService +from mantid.simpleapi import Load, Rebunch, Workspace, DeleteWorkspace class WISHWorkspaceDataSet(tc.utils.data.Dataset): @@ -17,19 +17,21 @@ def __init__(self, workspace): ws = Load(Filename=workspace, OutputWorkspace=workspace, EnableLogging=False) ws_name = ws else: - raise RuntimeError("Invalid workspace type - must be Workspace object or a name of a workspace to Load") + raise RuntimeError("Invalid workspace type - must be Workspace object or a name of a workspace to Load") - rebunched_ws_name = f"__{ws_name}_cnn_rebunched" - if AnalysisDataService.doesExist(rebunched_ws_name): - self.rebunched_ws = AnalysisDataService[rebunched_ws_name] - else: - self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace=rebunched_ws_name, EnableLogging=False) + self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace=f"__{ws_name}_cnn_rebunched", EnableLogging=False) self.ws_3d = make_3d_array(self.rebunched_ws) print(f"Data set for {workspace} is created with shape{self.ws_3d.shape}") self.trans = A.Compose([A.pytorch.transforms.ToTensorV2(p=1.0)]) def get_workspace(self): + if self.rebunched_ws is None: + raise RuntimeError("Rebunched workspace is not available!") return self.rebunched_ws + + def delete_rebunched_ws(self): + DeleteWorkspace(Workspace=self.rebunched_ws) + self.rebunched_ws = None def get_ws_as_3d_array(self): return self.ws_3d