Skip to content

Commit 593cb24

Browse files
HDBSCAN link added
1 parent 23898fe commit 593cb24

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold
5151
:param workspace: Workspace name or the object of Workspace from WISH, ex: "WISH0042730"
5252
:param output_ws_name: Name of the peaks workspace
5353
:param conf_threshold: Confidence threshold to filter peaks inferred from RCNN
54-
:param clustering: name of clustering method. Default is QLab and allowed
54+
:param clustering: name of clustering method(QLab or HDBSCAN). Default is QLab
5555
:param kwargs: variable keyword params for clustering methods
5656
"""
5757
start_time = time.time()
@@ -88,7 +88,15 @@ def _do_peak_clustering(self, detected_peaks, clustering, **kwargs):
8888

8989

9090
def _do_hdbscan_clustering(self, peakdata, keep_ignored_labels=True, **kwargs):
91-
data = np.delete(peakdata, [3,4], axis=1)
91+
"""
92+
Do HDBSCAN clustering over the inferred peak coordinates
93+
:param peakata: np array containig the inferred peak coordinates
94+
:param keep_ignored_labels: whether to include the unclustered peaks in final result.
95+
default is True, can be set to False via passing "keep_ignored_labels": False in kwargs
96+
:param kwargs: variable keyword params to be passed to HDBSCAN algorithm
97+
https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.HDBSCAN.html
98+
"""
99+
peak_indices = np.delete(peakdata, [3,4], axis=1)
92100
if ("keep_ignored_labels" in kwargs):
93101
keep_ignored_labels = kwargs.pop("keep_ignored_labels")
94102

@@ -101,17 +109,17 @@ def _do_hdbscan_clustering(self, peakdata, keep_ignored_labels=True, **kwargs):
101109
}
102110
hdbscan_params.update(kwargs)
103111
hdbscan = HDBSCAN(**hdbscan_params)
104-
hdbscan.fit(data)
105-
print(f"Silhouette score of the clusters={silhouette_score(data, hdbscan.labels_)}")
112+
hdbscan.fit(peak_indices)
113+
print(f"Silhouette score of the clusters={silhouette_score(peak_indices, hdbscan.labels_)}")
106114

107115
if keep_ignored_labels:
108-
selected_peaks = np.concatenate((hdbscan.medoids_, data[np.where(hdbscan.labels_==-1)]), axis=0)
116+
selected_peak_indices = np.concatenate((hdbscan.medoids_, peak_indices[np.where(hdbscan.labels_==-1)]), axis=0)
109117
else:
110-
selected_peaks = hdbscan.medoids_
118+
selected_peak_indices = hdbscan.medoids_
111119
confidence = []
112-
for peak in selected_peaks:
113-
confidence.append(peakdata[np.where((data == peak).all(axis=1))[0].item(), -1])
114-
return np.column_stack((selected_peaks, confidence))
120+
for peak in selected_peak_indices:
121+
confidence.append(peakdata[np.where((peak_indices == peak).all(axis=1))[0].item(), -1])
122+
return np.column_stack((selected_peak_indices, confidence))
115123

116124

117125
def _do_cnn_inferencing(self, workspace):
Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,29 @@
11
Bragg Peaks detection using a pre-trained Faster RCNN deep neural network
22
================
33

4-
Inorder to use the pre-trained Faster RCNN model inside mantid using an IDAaaS instance, below steps are required.
4+
Inorder to run the pre-trained Faster RCNN model via mantid inside an IDAaaS instance, below steps are required.
55

6-
* Launch an IDAaaS instance with GPUs from WISH > Wish Single Crystal GPU Advanced
7-
* Launch Mantid workbench nightly from Applications->Software->Mantid->Mantid Workbench Nightly
6+
* Launch an IDAaaS instance with GPUs selected from WISH > Wish Single Crystal GPU Advanced
7+
* From IDAaaS, launch Mantid workbench nightly from Applications->Software->Mantid->Mantid Workbench Nightly
88
* Download `scriptrepository\diffraction\WISH` directory from mantid's script repository as instructed here https://docs.mantidproject.org/nightly/workbench/scriptrepository.html
99
* Check whether `<local path>\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories` of Mantid workbench.
10-
* Below is an example code snippet to test the code. It will create a peaks workspace with the inferred peaks from the cnn. The valid values for the clustering are QLab or HDBSCAN.
10+
* Below is an example code snippet to use the pretrained model for Bragg peak detection. It will create a peaks workspace with the inferred peaks from the model. The valid values for the `clustering` argument are `QLab` or `HDBSCAN`. For `QLab` method the default value of `q_tol=0.05` will be used for `BaseSX.remove_duplicate_peaks_by_qlab` method.
1111
```python
1212
from cnn.BraggDetectCNN import BraggDetectCNN
1313
model_weights = r'/mnt/ceph/auxiliary/wish/BraggDetect_FasterRCNN_Resnet50_Weights_v1.pt'
1414
cnn_peaks_detector = BraggDetectCNN(model_weights_path=model_weights, batch_size=64)
1515
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="QLab")
1616
```
1717
* If the above import is not working, check whether the `<local path>\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories`.
18+
* Depending on the selected `clustering` method in the above, the user can provide custom parameters using `kwargs` as shown below.
19+
```
20+
kwargs={"q_tol": 0.1}
21+
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="QLab", **kwargs)
22+
23+
or
24+
25+
kwargs={"cluster_selection_method": "leaf", "algorithm": "brute", "keep_ignored_labels": False}
26+
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="HDBSCAN", **kwargs)
27+
```
28+
* The documentation for using HDBSCAN can be found here: https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.HDBSCAN.html
29+
* The documentation for using `BaseSX.remove_duplicate_peaks_by_qlab` can be found here: https://docs.mantidproject.org/nightly/techniques/ISIS_SingleCrystalDiffraction_Workflow.html

0 commit comments

Comments
 (0)