Skip to content

Commit aa85faa

Browse files
committed
Add baseline cluster metrics to model summary
1 parent aa8e056 commit aa85faa

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

src/model/clustering.rs

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use comfy_table::{
1111
};
1212
use smartcore::linalg::basic::arrays::{Array1, Array2};
1313
use smartcore::numbers::{basenum::Number, floatnum::FloatNumber, realnum::RealNumber};
14+
use std::collections::BTreeSet;
1415
use std::fmt::{Display, Formatter};
1516

1617
/// Trains clustering models
@@ -53,8 +54,9 @@ where
5354
for algorithm_name in self.settings.selected_algorithms() {
5455
let algorithm = ClusteringAlgorithm::from_name(algorithm_name);
5556
let fitted = algorithm.fit(&self.x_train, &self.settings);
56-
self.trained_algorithms
57-
.push(TrainedClusteringAlgorithm::new(algorithm_name, fitted));
57+
let mut trained = TrainedClusteringAlgorithm::new(algorithm_name, fitted);
58+
trained.compute_baseline(&self.x_train, &self.settings);
59+
self.trained_algorithms.push(trained);
5860
}
5961
}
6062

@@ -132,6 +134,8 @@ where
132134
table.apply_modifier(UTF8_SOLID_INNER_BORDERS);
133135
table.set_header(vec![
134136
Cell::new("Model").add_attribute(Attribute::Bold),
137+
Cell::new("Clusters").add_attribute(Attribute::Bold),
138+
Cell::new("Noise").add_attribute(Attribute::Bold),
135139
Cell::new("Homogeneity").add_attribute(Attribute::Bold),
136140
Cell::new("Completeness").add_attribute(Attribute::Bold),
137141
Cell::new("V-Measure").add_attribute(Attribute::Bold),
@@ -144,6 +148,8 @@ where
144148
"-".to_string(),
145149
"-".to_string(),
146150
"-".to_string(),
151+
"-".to_string(),
152+
"-".to_string(),
147153
]);
148154
}
149155
} else {
@@ -156,6 +162,22 @@ where
156162
}
157163
}
158164

165+
/// Aggregate cluster statistics that do not require ground-truth labels.
///
/// Both values are plain counts derived from a model's own predictions, so
/// they can be reported even when no reference labels are available.
166+
#[derive(Debug, Clone, Copy)]
167+
struct ClusterBaseline {
168+
/// Number of distinct cluster labels found in the predictions
/// (labels counted as noise are excluded).
cluster_count: usize,
169+
/// Number of predicted points that were counted as noise instead of
/// being assigned to a cluster.
noise_count: usize,
170+
}
171+
172+
impl ClusterBaseline {
173+
/// Builds a baseline from an already-computed cluster count and noise count.
const fn new(cluster_count: usize, noise_count: usize) -> Self {
174+
Self {
175+
cluster_count,
176+
noise_count,
177+
}
178+
}
179+
}
180+
159181
/// Trained clustering algorithm with optional metrics.
160182
struct TrainedClusteringAlgorithm<INPUT, CLUSTER, InputArray, ClusterArray>
161183
where
@@ -167,6 +189,7 @@ where
167189
algorithm_name: ClusteringAlgorithmName,
168190
algorithm: ClusteringAlgorithm<INPUT, CLUSTER, InputArray, ClusterArray>,
169191
metrics: Option<HCVScore<CLUSTER>>,
192+
baseline: Option<ClusterBaseline>,
170193
}
171194

172195
impl<INPUT, CLUSTER, InputArray, ClusterArray>
@@ -185,13 +208,35 @@ where
185208
algorithm_name,
186209
algorithm,
187210
metrics: None,
211+
baseline: None,
188212
}
189213
}
190214

191215
/// Predicts cluster labels for `x` by delegating to the wrapped algorithm.
///
/// Returns whatever `ModelResult` the underlying algorithm's `predict` produces.
fn predict(&self, x: &InputArray, settings: &ClusteringSettings) -> ModelResult<ClusterArray> {
192216
self.algorithm.predict(x, settings)
193217
}
194218

219+
/// Computes label-free summary statistics for this algorithm and stores them
/// in `self.baseline`.
///
/// Runs `predict` on `x`, then counts the distinct predicted labels and the
/// number of points treated as noise. If prediction fails, the error is
/// discarded and `self.baseline` is reset to `None`.
fn compute_baseline(&mut self, x: &InputArray, settings: &ClusteringSettings) {
220+
// Best-effort: a failed prediction clears the baseline rather than propagating the error.
let Ok(predictions) = self.predict(x, settings) else {
221+
self.baseline = None;
222+
return;
223+
};
224+
225+
// The set deduplicates labels; its final length is the cluster count.
let mut unique_clusters: BTreeSet<CLUSTER> = BTreeSet::new();
226+
let mut noise_count = 0_usize;
227+
228+
for label in predictions.iterator(0) {
229+
let value = *label;
230+
// For DBSCAN only, a label equal to `CLUSTER::zero()` is counted as noise and
// is not treated as a cluster.
// NOTE(review): presumably DBSCAN's predict reports outliers as zero — confirm
// against the smartcore version in use.
if self.algorithm_name == ClusteringAlgorithmName::DBSCAN && value == CLUSTER::zero() {
231+
noise_count += 1;
232+
} else {
233+
unique_clusters.insert(value);
234+
}
235+
}
236+
237+
self.baseline = Some(ClusterBaseline::new(unique_clusters.len(), noise_count));
238+
}
239+
195240
fn display_row(&self) -> Vec<String> {
196241
let (homogeneity, completeness, v_measure) = if let Some(scores) = &self.metrics {
197242
let format_score = |s: Option<f64>| match s {
@@ -207,8 +252,19 @@ where
207252
("-".to_string(), "-".to_string(), "-".to_string())
208253
};
209254

255+
let (clusters, noise) = if let Some(baseline) = &self.baseline {
256+
(
257+
baseline.cluster_count.to_string(),
258+
baseline.noise_count.to_string(),
259+
)
260+
} else {
261+
("-".to_string(), "-".to_string())
262+
};
263+
210264
vec![
211265
self.algorithm_name.to_string(),
266+
clusters,
267+
noise,
212268
homogeneity,
213269
completeness,
214270
v_measure,

tests/clustering.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ fn clustering_model_display_shows_metrics() {
8989
assert!(output.contains("KMeans"));
9090
assert!(output.contains("Agglomerative"));
9191
assert!(output.contains("DBSCAN"));
92+
assert!(output.contains("Clusters"));
93+
assert!(output.contains("Noise"));
9294
assert!(output.contains("V-Measure"));
9395
assert!(output.contains("1.00"));
9496
}
@@ -126,6 +128,23 @@ fn clustering_model_display_shows_configured_algorithm_when_untrained() {
126128
assert!(output.contains("Homogeneity"));
127129
}
128130

131+
// Checks that the rendered model summary includes the "Clusters" and "Noise"
// columns after training when no metrics have been computed against reference labels.
#[test]
132+
fn clustering_model_display_shows_baseline_without_ground_truth() {
133+
// Arrange
134+
let x = clustering_testing_data();
135+
let mut model: ClusteringModel<f64, u8, DenseMatrix<f64>, Vec<u8>> =
136+
ClusteringModel::new(x.clone(), ClusteringSettings::default().with_k(2));
137+
model.train();
138+
139+
// Act
140+
let output = format!("{model}");
141+
142+
// Assert
143+
assert!(output.contains("Clusters"));
144+
assert!(output.contains("Noise"));
145+
// NOTE(review): this matches any '2' anywhere in the rendered output, not
// specifically a value in the "Clusters" column — consider a tighter assertion.
assert!(output.contains('2'));
146+
}
147+
129148
#[test]
130149
fn clustering_model_display_clears_metrics_after_retraining() {
131150
// Arrange

0 commit comments

Comments
 (0)