@@ -170,7 +170,7 @@ def _predictions_matches_labels(
170
170
return labels == predicted_label
171
171
172
172
def _should_keep_prediction (
173
- self , predicted_scores : List [OutputScore ], actual_label : str
173
+ self , predicted_scores : List [OutputScore ], actual_label : OutputScore
174
174
) -> bool :
175
175
# filter by class
176
176
if len (self ._config .classes ) != 0 :
@@ -180,13 +180,14 @@ def _should_keep_prediction(
180
180
return False
181
181
182
182
# filter by accuracy
183
+ label_name = actual_label .label
183
184
if self ._config .prediction == "all" :
184
185
pass
185
186
elif self ._config .prediction == "correct" :
186
- if not self ._predictions_matches_labels (predicted_scores , actual_label ):
187
+ if not self ._predictions_matches_labels (predicted_scores , label_name ):
187
188
return False
188
189
elif self ._config .prediction == "incorrect" :
189
- if self ._predictions_matches_labels (predicted_scores , actual_label ):
190
+ if self ._predictions_matches_labels (predicted_scores , label_name ):
190
191
return False
191
192
else :
192
193
raise Exception (f"Invalid prediction config: { self ._config .prediction } " )
@@ -238,16 +239,16 @@ def _calculate_vis_output(
238
239
predicted = predicted .cpu ().squeeze (0 )
239
240
240
241
if label is not None and len (label ) > 0 :
241
- actual_label = OutputScore (
242
- score = 0 , index = label [0 ], label = self .classes [label [0 ]]
242
+ actual_label_output = OutputScore (
243
+ score = 100 , index = label [0 ], label = self .classes [label [0 ]]
243
244
)
244
245
else :
245
- actual_label = None
246
+ actual_label_output = None
246
247
247
248
predicted_scores = self ._get_labels_from_scores (scores , predicted )
248
249
249
250
# Filter based on UI configuration
250
- if not self ._should_keep_prediction (predicted_scores , actual_label ):
251
+ if not self ._should_keep_prediction (predicted_scores , actual_label_output ):
251
252
return None
252
253
253
254
baselines = [tuple (b ) for b in baselines ]
@@ -277,9 +278,9 @@ def _calculate_vis_output(
277
278
278
279
return VisualizationOutput (
279
280
feature_outputs = features_per_input ,
280
- actual = actual_label ,
281
+ actual = actual_label_output ,
281
282
predicted = predicted_scores ,
282
- active_index = target if target is not None else actual_label .index ,
283
+ active_index = target if target is not None else actual_label_output .index ,
283
284
)
284
285
285
286
def _get_outputs (self ) -> List [VisualizationOutput ]:
0 commit comments