Skip to content

Commit

Permalink
removed rebunched ws after inferencing
Browse files Browse the repository at this point in the history
  • Loading branch information
warunawickramasingha committed Aug 6, 2024
1 parent 1ebf2f6 commit 1be79bc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
1 change: 1 addition & 0 deletions diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold

#Filter duplicates by qlab
BaseSX.remove_duplicate_peaks_by_qlab(peaksws, q_tol)
data_set.delete_rebunched_ws()
print(f"Bragg peaks finding from FasterRCNN model is completed in {time.time()-start_time} seconds!")


Expand Down
16 changes: 9 additions & 7 deletions diffraction/WISH/bragg-detect/cnn/WISHDataSets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from bragg_utils import make_3d_array
from mantid.simpleapi import Load, Rebunch, Workspace, AnalysisDataService
from mantid.simpleapi import Load, Rebunch, Workspace, DeleteWorkspace


class WISHWorkspaceDataSet(tc.utils.data.Dataset):
Expand All @@ -17,19 +17,21 @@ def __init__(self, workspace):
ws = Load(Filename=workspace, OutputWorkspace=workspace, EnableLogging=False)
ws_name = ws
else:
raise RuntimeError("Invalid workspace type - must be Workspace object or a name of a workspace to Load")
raise RuntimeError("Invalid workspace type - must be Workspace object or a name of a workspace to Load")

rebunched_ws_name = f"__{ws_name}_cnn_rebunched"
if AnalysisDataService.doesExist(rebunched_ws_name):
self.rebunched_ws = AnalysisDataService[rebunched_ws_name]
else:
self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace=rebunched_ws_name, EnableLogging=False)
self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace=f"__{ws_name}_cnn_rebunched", EnableLogging=False)
self.ws_3d = make_3d_array(self.rebunched_ws)
print(f"Data set for {workspace} is created with shape{self.ws_3d.shape}")
self.trans = A.Compose([A.pytorch.transforms.ToTensorV2(p=1.0)])

def get_workspace(self):
if self.rebunched_ws is None:
raise RuntimeError("Rebunched workspace is not available!")
return self.rebunched_ws

def delete_rebunched_ws(self):
DeleteWorkspace(Workspace=self.rebunched_ws)
self.rebunched_ws = None

def get_ws_as_3d_array(self):
return self.ws_3d
Expand Down

0 comments on commit 1be79bc

Please sign in to comment.