Skip to content

Commit b2eb0fd

Browse files
Merge pull request #17 from mantidproject/38269_peak_clustering
HDBSCAN clustering method added for Bragg peaks inferred from DL model
2 parents 09f1e1c + 593cb24 commit b2eb0fd

File tree

3 files changed

+98
-20
lines changed

3 files changed

+98
-20
lines changed

diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@
88
from tqdm import tqdm
99
from Diffraction.single_crystal.base_sx import BaseSX
1010
import time
11+
from enum import Enum
12+
from sklearn.cluster import HDBSCAN
13+
from sklearn.metrics import silhouette_score
14+
15+
class Clustering(Enum):
16+
QLab = 1
17+
HDBSCAN = 2
18+
1119

1220
class BraggDetectCNN:
1321
"""
@@ -19,7 +27,7 @@ class BraggDetectCNN:
1927
2028
# 2) Create a peaks workspace containing bragg peaks detected with a confidence greater than conf_threshold
2129
cnn_bragg_peaks_detector = BraggDetectCNN(model_weights_path=cnn_weights_path, batch_size=64, workers=0, iou_threshold=0.001)
22-
cnn_bragg_peaks_detector.find_bragg_peaks(workspace="WISH00042730", conf_threshold=0.0, q_tol=0.05)
30+
cnn_bragg_peaks_detector.find_bragg_peaks(workspace="WISH00042730", conf_threshold=0.0, clustering="QLab", q_tol=0.05)
2331
"""
2432

2533
def __init__(self, model_weights_path, batch_size=64, workers=0, iou_threshold=0.001):
@@ -37,28 +45,83 @@ def __init__(self, model_weights_path, batch_size=64, workers=0, iou_threshold=0
3745
self.iou_threshold = iou_threshold
3846

3947

40-
def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold=0.0, q_tol=0.05):
48+
def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering=Clustering.QLab.name, **kwargs):
4149
"""
4250
Find bragg peaks using the pre trained FasterRCNN model and create a peaks workspace
4351
:param workspace: Workspace name or the object of Workspace from WISH, ex: "WISH0042730"
4452
:param output_ws_name: Name of the peaks workspace
4553
:param conf_threshold: Confidence threshold to filter peaks inferred from RCNN
46-
:param q_tol: qlab tolerance to remove duplicate peaks
54+
:param clustering: name of clustering method(QLab or HDBSCAN). Default is QLab
55+
:param kwargs: variable keyword params for clustering methods
4756
"""
4857
start_time = time.time()
4958
data_set, predicted_indices = self._do_cnn_inferencing(workspace)
59+
5060
filtered_indices = predicted_indices[predicted_indices[:, -1] > conf_threshold]
51-
filtered_indices_rounded = np.round(filtered_indices[:, :-1]).astype(int)
52-
peaksws = createPeaksWorkspaceFromIndices(data_set.get_workspace(), output_ws_name, filtered_indices_rounded, data_set.get_ws_as_3d_array())
61+
62+
#Do Clustering
63+
print(f"Starting peak clustering with {clustering} method..")
64+
clustered_peaks = self._do_peak_clustering(filtered_indices, clustering, **kwargs)
65+
cluster_indices_rounded = np.round(clustered_peaks[:, :3]).astype(int)
66+
peaksws = createPeaksWorkspaceFromIndices(data_set.get_workspace(), output_ws_name, cluster_indices_rounded, data_set.get_ws_as_3d_array())
5367
for ipk, pk in enumerate(peaksws):
54-
pk.setIntensity(filtered_indices[ipk, -1])
68+
pk.setIntensity(clustered_peaks[ipk, -1])
69+
70+
if clustering == Clustering.QLab.name:
71+
#Filter peaks by qlab
72+
clustering_params = {"q_tol": 0.05 }
73+
clustering_params.update(kwargs)
74+
BaseSX.remove_duplicate_peaks_by_qlab(peaksws, **clustering_params)
75+
76+
print(f"Number of peaks after clustering is = {len(peaksws)}")
5577

56-
#Filter duplicates by qlab
57-
BaseSX.remove_duplicate_peaks_by_qlab(peaksws, q_tol)
5878
data_set.delete_rebunched_ws()
59-
print(f"Bragg peaks finding from FasterRCNN model is completed in {time.time()-start_time} seconds!")
79+
print(f"Bragg peaks finding from FasterRCNN model is completed in {time.time()-start_time:.2f} seconds!")
80+
81+
82+
def _do_peak_clustering(self, detected_peaks, clustering, **kwargs):
83+
print(f"Number of peaks before clustering = {len(detected_peaks)}")
84+
if clustering == Clustering.HDBSCAN.name:
85+
return self._do_hdbscan_clustering(detected_peaks, **kwargs)
86+
else:
87+
return detected_peaks
6088

6189

90+
def _do_hdbscan_clustering(self, peakdata, keep_ignored_labels=True, **kwargs):
91+
"""
92+
Do HDBSCAN clustering over the inferred peak coordinates
93+
:param peakata: np array containig the inferred peak coordinates
94+
:param keep_ignored_labels: whether to include the unclustered peaks in final result.
95+
default is True, can be set to False via passing "keep_ignored_labels": False in kwargs
96+
:param kwargs: variable keyword params to be passed to HDBSCAN algorithm
97+
https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.HDBSCAN.html
98+
"""
99+
peak_indices = np.delete(peakdata, [3,4], axis=1)
100+
if ("keep_ignored_labels" in kwargs):
101+
keep_ignored_labels = kwargs.pop("keep_ignored_labels")
102+
103+
hdbscan_params = {"min_cluster_size": 2,
104+
"min_samples": 2,
105+
"store_centers" : "medoid",
106+
"algorithm": "auto",
107+
"cluster_selection_method": "eom",
108+
"metric": "euclidean"
109+
}
110+
hdbscan_params.update(kwargs)
111+
hdbscan = HDBSCAN(**hdbscan_params)
112+
hdbscan.fit(peak_indices)
113+
print(f"Silhouette score of the clusters={silhouette_score(peak_indices, hdbscan.labels_)}")
114+
115+
if keep_ignored_labels:
116+
selected_peak_indices = np.concatenate((hdbscan.medoids_, peak_indices[np.where(hdbscan.labels_==-1)]), axis=0)
117+
else:
118+
selected_peak_indices = hdbscan.medoids_
119+
confidence = []
120+
for peak in selected_peak_indices:
121+
confidence.append(peakdata[np.where((peak_indices == peak).all(axis=1))[0].item(), -1])
122+
return np.column_stack((selected_peak_indices, confidence))
123+
124+
62125
def _do_cnn_inferencing(self, workspace):
63126
data_set = WISHWorkspaceDataSet(workspace)
64127
data_loader = tc.utils.data.DataLoader(data_set, batch_size=self.batch_size, shuffle=False, num_workers=self.workers)
@@ -71,9 +134,14 @@ def _do_cnn_inferencing(self, workspace):
71134
prediction = self.model([img.to(self.device)])[0]
72135
nms_prediction = self._apply_nms(prediction, self.iou_threshold)
73136
for box, score in zip(nms_prediction['boxes'], nms_prediction['scores']):
137+
box = box.cpu().numpy().astype(int)
74138
tof = (box[0]+box[2])/2
75139
tube_res = (box[1]+box[3])/2
76-
predicted_indices_with_score.append([tube_idx, tube_res.cpu(), tof.cpu(), score.cpu()])
140+
141+
boxsum = np.sum(img[0, box[1]:box[3], box[0]:box[2]].numpy())
142+
143+
predicted_indices_with_score.append([tube_idx, tube_res, tof, boxsum, score.cpu()])
144+
77145
return data_set, np.array(predicted_indices_with_score)
78146

79147

@@ -98,7 +166,7 @@ def _select_device(self):
98166

99167
def _load_cnn_model_from_weights(self, weights_path):
100168
model = self._get_fasterrcnn_resnet50_fpn(num_classes=2)
101-
model.load_state_dict(tc.load(weights_path, map_location=self.device))
169+
model.load_state_dict(tc.load(weights_path, map_location=self.device, weights_only=True))
102170
return model.to(self.device)
103171

104172

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,29 @@
11
Bragg Peaks detection using a pre-trained Faster RCNN deep neural network
22
================
33

4-
Inorder to use the pre-trained Faster RCNN model inside mantid using an IDAaaS instance, below steps are required.
4+
Inorder to run the pre-trained Faster RCNN model via mantid inside an IDAaaS instance, below steps are required.
55

6-
* Launch an IDAaaS instance with GPUs from WISH > Wish Single Crystal GPU Advanced
7-
* Launch Mantid workbench nightly from Applications->Software->Mantid->Mantid Workbench Nightly
6+
* Launch an IDAaaS instance with GPUs selected from WISH > Wish Single Crystal GPU Advanced
7+
* From IDAaaS, launch Mantid workbench nightly from Applications->Software->Mantid->Mantid Workbench Nightly
88
* Download `scriptrepository\diffraction\WISH` directory from mantid's script repository as instructed here https://docs.mantidproject.org/nightly/workbench/scriptrepository.html
99
* Check whether `<local path>\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories` of Mantid workbench.
10-
* Below is an example code snippet to test the code. It will create a peaks workspace with the inferred peaks from the cnn and will do a peak filtering using the q_tol provided using `BaseSX.remove_duplicate_peaks_by_qlab`.
10+
* Below is an example code snippet to use the pretrained model for Bragg peak detection. It will create a peaks workspace with the inferred peaks from the model. The valid values for the `clustering` argument are `QLab` or `HDBSCAN`. For `QLab` method the default value of `q_tol=0.05` will be used for `BaseSX.remove_duplicate_peaks_by_qlab` method.
1111
```python
1212
from cnn.BraggDetectCNN import BraggDetectCNN
1313
model_weights = r'/mnt/ceph/auxiliary/wish/BraggDetect_FasterRCNN_Resnet50_Weights_v1.pt'
1414
cnn_peaks_detector = BraggDetectCNN(model_weights_path=model_weights, batch_size=64)
15-
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, q_tol=0.05)
15+
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="QLab")
1616
```
1717
* If the above import is not working, check whether the `<local path>\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories`.
18+
* Depending on the selected `clustering` method in the above, the user can provide custom parameters using `kwargs` as shown below.
19+
```
20+
kwargs={"q_tol": 0.1}
21+
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="QLab", **kwargs)
22+
23+
or
24+
25+
kwargs={"cluster_selection_method": "leaf", "algorithm": "brute", "keep_ignored_labels": False}
26+
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="HDBSCAN", **kwargs)
27+
```
28+
* The documentation for using HDBSCAN can be found here: https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.HDBSCAN.html
29+
* The documentation for using `BaseSX.remove_duplicate_peaks_by_qlab` can be found here: https://docs.mantidproject.org/nightly/techniques/ISIS_SingleCrystalDiffraction_Workflow.html
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
-f https://download.pytorch.org/whl/cu118
2-
torch
3-
torchvision
4-
1+
torch==2.5.1
2+
torchvision==0.20.1
53
albumentations==1.4.0
64
tqdm==4.66.3

0 commit comments

Comments
 (0)