@@ -11,6 +11,7 @@ use comfy_table::{
1111} ;
1212use smartcore:: linalg:: basic:: arrays:: { Array1 , Array2 } ;
1313use smartcore:: numbers:: { basenum:: Number , floatnum:: FloatNumber , realnum:: RealNumber } ;
14+ use std:: collections:: BTreeSet ;
1415use std:: fmt:: { Display , Formatter } ;
1516
1617/// Trains clustering models
5354 for algorithm_name in self . settings . selected_algorithms ( ) {
5455 let algorithm = ClusteringAlgorithm :: from_name ( algorithm_name) ;
5556 let fitted = algorithm. fit ( & self . x_train , & self . settings ) ;
56- self . trained_algorithms
57- . push ( TrainedClusteringAlgorithm :: new ( algorithm_name, fitted) ) ;
57+ let mut trained = TrainedClusteringAlgorithm :: new ( algorithm_name, fitted) ;
58+ trained. compute_baseline ( & self . x_train , & self . settings ) ;
59+ self . trained_algorithms . push ( trained) ;
5860 }
5961 }
6062
@@ -132,6 +134,8 @@ where
132134 table. apply_modifier ( UTF8_SOLID_INNER_BORDERS ) ;
133135 table. set_header ( vec ! [
134136 Cell :: new( "Model" ) . add_attribute( Attribute :: Bold ) ,
137+ Cell :: new( "Clusters" ) . add_attribute( Attribute :: Bold ) ,
138+ Cell :: new( "Noise" ) . add_attribute( Attribute :: Bold ) ,
135139 Cell :: new( "Homogeneity" ) . add_attribute( Attribute :: Bold ) ,
136140 Cell :: new( "Completeness" ) . add_attribute( Attribute :: Bold ) ,
137141 Cell :: new( "V-Measure" ) . add_attribute( Attribute :: Bold ) ,
@@ -144,6 +148,8 @@ where
144148 "-" . to_string( ) ,
145149 "-" . to_string( ) ,
146150 "-" . to_string( ) ,
151+ "-" . to_string( ) ,
152+ "-" . to_string( ) ,
147153 ] ) ;
148154 }
149155 } else {
@@ -156,6 +162,22 @@ where
156162 }
157163}
158164
165+ /// Aggregate cluster statistics that do not require ground-truth labels.
166+ #[ derive( Debug , Clone , Copy ) ]
167+ struct ClusterBaseline {
168+ cluster_count : usize ,
169+ noise_count : usize ,
170+ }
171+
172+ impl ClusterBaseline {
173+ const fn new ( cluster_count : usize , noise_count : usize ) -> Self {
174+ Self {
175+ cluster_count,
176+ noise_count,
177+ }
178+ }
179+ }
180+
159181/// Trained clustering algorithm with optional metrics.
160182struct TrainedClusteringAlgorithm < INPUT , CLUSTER , InputArray , ClusterArray >
161183where
@@ -167,6 +189,7 @@ where
167189 algorithm_name : ClusteringAlgorithmName ,
168190 algorithm : ClusteringAlgorithm < INPUT , CLUSTER , InputArray , ClusterArray > ,
169191 metrics : Option < HCVScore < CLUSTER > > ,
192+ baseline : Option < ClusterBaseline > ,
170193}
171194
172195impl < INPUT , CLUSTER , InputArray , ClusterArray >
@@ -185,13 +208,35 @@ where
185208 algorithm_name,
186209 algorithm,
187210 metrics : None ,
211+ baseline : None ,
188212 }
189213 }
190214
191215 fn predict ( & self , x : & InputArray , settings : & ClusteringSettings ) -> ModelResult < ClusterArray > {
192216 self . algorithm . predict ( x, settings)
193217 }
194218
219+ fn compute_baseline ( & mut self , x : & InputArray , settings : & ClusteringSettings ) {
220+ let Ok ( predictions) = self . predict ( x, settings) else {
221+ self . baseline = None ;
222+ return ;
223+ } ;
224+
225+ let mut unique_clusters: BTreeSet < CLUSTER > = BTreeSet :: new ( ) ;
226+ let mut noise_count = 0_usize ;
227+
228+ for label in predictions. iterator ( 0 ) {
229+ let value = * label;
230+ if self . algorithm_name == ClusteringAlgorithmName :: DBSCAN && value == CLUSTER :: zero ( ) {
231+ noise_count += 1 ;
232+ } else {
233+ unique_clusters. insert ( value) ;
234+ }
235+ }
236+
237+ self . baseline = Some ( ClusterBaseline :: new ( unique_clusters. len ( ) , noise_count) ) ;
238+ }
239+
195240 fn display_row ( & self ) -> Vec < String > {
196241 let ( homogeneity, completeness, v_measure) = if let Some ( scores) = & self . metrics {
197242 let format_score = |s : Option < f64 > | match s {
@@ -207,8 +252,19 @@ where
207252 ( "-" . to_string ( ) , "-" . to_string ( ) , "-" . to_string ( ) )
208253 } ;
209254
255+ let ( clusters, noise) = if let Some ( baseline) = & self . baseline {
256+ (
257+ baseline. cluster_count . to_string ( ) ,
258+ baseline. noise_count . to_string ( ) ,
259+ )
260+ } else {
261+ ( "-" . to_string ( ) , "-" . to_string ( ) )
262+ } ;
263+
210264 vec ! [
211265 self . algorithm_name. to_string( ) ,
266+ clusters,
267+ noise,
212268 homogeneity,
213269 completeness,
214270 v_measure,
0 commit comments