From 593cb242edf11dea39a29f4016590847b3b65ced Mon Sep 17 00:00:00 2001
From: Waruna Wickramasingha <waruna.wickramasingha@stfc.ac.uk>
Date: Fri, 3 Jan 2025 15:47:18 +0000
Subject: [PATCH] HDBSCAN link added

---
 .../WISH/bragg-detect/cnn/BraggDetectCNN.py   | 26 ++++++++++++-------
 diffraction/WISH/bragg-detect/cnn/README.md   | 20 +++++++++++---
 2 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py b/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py
index 13974b6..a10992b 100644
--- a/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py
+++ b/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py
@@ -51,7 +51,7 @@ def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold
         :param workspace: Workspace name or the object of Workspace from WISH, ex: "WISH0042730"
         :param output_ws_name: Name of the peaks workspace
         :param conf_threshold: Confidence threshold to filter peaks inferred from RCNN
-        :param clustering: name of clustering method. Default is QLab and allowed
+        :param clustering: name of clustering method(QLab or HDBSCAN). Default is QLab
         :param kwargs: variable keyword params for clustering methods
         """
         start_time = time.time()
@@ -88,7 +88,15 @@ def _do_peak_clustering(self, detected_peaks, clustering, **kwargs):
 
 
     def _do_hdbscan_clustering(self, peakdata, keep_ignored_labels=True, **kwargs):
-        data = np.delete(peakdata, [3,4], axis=1)
+        """
+        Do HDBSCAN clustering over the inferred peak coordinates
+        :param peakata: np array containig the inferred peak coordinates
+        :param keep_ignored_labels: whether to include the unclustered peaks in final result.
+            default is True, can be set to False via passing "keep_ignored_labels": False in kwargs
+        :param kwargs: variable keyword params to be passed to HDBSCAN algorithm 
+            https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.HDBSCAN.html
+        """
+        peak_indices = np.delete(peakdata, [3,4], axis=1)
         if ("keep_ignored_labels" in kwargs):
             keep_ignored_labels = kwargs.pop("keep_ignored_labels")
 
@@ -101,17 +109,17 @@ def _do_hdbscan_clustering(self, peakdata, keep_ignored_labels=True, **kwargs):
                           }
         hdbscan_params.update(kwargs)
         hdbscan = HDBSCAN(**hdbscan_params)
-        hdbscan.fit(data)
-        print(f"Silhouette score of the clusters={silhouette_score(data, hdbscan.labels_)}")
+        hdbscan.fit(peak_indices)
+        print(f"Silhouette score of the clusters={silhouette_score(peak_indices, hdbscan.labels_)}")
 
         if keep_ignored_labels:
-            selected_peaks = np.concatenate((hdbscan.medoids_, data[np.where(hdbscan.labels_==-1)]), axis=0)
+            selected_peak_indices = np.concatenate((hdbscan.medoids_, peak_indices[np.where(hdbscan.labels_==-1)]), axis=0)
         else:
-            selected_peaks = hdbscan.medoids_
+            selected_peak_indices = hdbscan.medoids_
         confidence = []
-        for peak in selected_peaks:
-            confidence.append(peakdata[np.where((data == peak).all(axis=1))[0].item(), -1])
-        return np.column_stack((selected_peaks, confidence))
+        for peak in selected_peak_indices:
+            confidence.append(peakdata[np.where((peak_indices == peak).all(axis=1))[0].item(), -1])
+        return np.column_stack((selected_peak_indices, confidence))
     
 
     def _do_cnn_inferencing(self, workspace):
diff --git a/diffraction/WISH/bragg-detect/cnn/README.md b/diffraction/WISH/bragg-detect/cnn/README.md
index 6760580..f374367 100644
--- a/diffraction/WISH/bragg-detect/cnn/README.md
+++ b/diffraction/WISH/bragg-detect/cnn/README.md
@@ -1,13 +1,13 @@
 Bragg Peaks detection using a pre-trained Faster RCNN deep neural network 
 ================
 
-Inorder to use the pre-trained Faster RCNN model inside mantid using an IDAaaS instance, below steps are required.
+Inorder to run the pre-trained Faster RCNN model via mantid inside an IDAaaS instance, below steps are required.
 
-* Launch an IDAaaS instance with GPUs from WISH > Wish Single Crystal GPU Advanced
-* Launch Mantid workbench nightly from Applications->Software->Mantid->Mantid Workbench Nightly 
+* Launch an IDAaaS instance with GPUs selected from WISH > Wish Single Crystal GPU Advanced
+* From IDAaaS, launch Mantid workbench nightly from Applications->Software->Mantid->Mantid Workbench Nightly 
 * Download `scriptrepository\diffraction\WISH` directory from mantid's script repository as instructed here https://docs.mantidproject.org/nightly/workbench/scriptrepository.html
 * Check whether `<local path>\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories` of Mantid workbench.
-* Below is an example code snippet to test the code. It will create a peaks workspace with the inferred peaks from the cnn. The valid values for the clustering are QLab or HDBSCAN.
+* Below is an example code snippet to use the pretrained model for Bragg peak detection. It will create a peaks workspace with the inferred peaks from the model. The valid values for the `clustering` argument are `QLab` or `HDBSCAN`. For `QLab` method the default value of `q_tol=0.05` will be used for `BaseSX.remove_duplicate_peaks_by_qlab` method. 
 ```python
 from cnn.BraggDetectCNN import BraggDetectCNN
 model_weights = r'/mnt/ceph/auxiliary/wish/BraggDetect_FasterRCNN_Resnet50_Weights_v1.pt'
@@ -15,3 +15,15 @@ cnn_peaks_detector = BraggDetectCNN(model_weights_path=model_weights, batch_size
 cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="QLab")
 ```
 * If the above import is not working, check whether the `<local path>\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories`.
+* Depending on the selected `clustering` method in the above, the user can provide custom parameters using `kwargs` as shown below.
+```
+kwargs={"q_tol": 0.1} 
+cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="QLab", **kwargs)
+
+or 
+
+kwargs={"cluster_selection_method": "leaf", "algorithm": "brute", "keep_ignored_labels": False} 
+cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="HDBSCAN", **kwargs)
+```
+* The documentation for using HDBSCAN can be found here: https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.HDBSCAN.html
+* The documentation for using `BaseSX.remove_duplicate_peaks_by_qlab` can be found here: https://docs.mantidproject.org/nightly/techniques/ISIS_SingleCrystalDiffraction_Workflow.html