Skip to content

Commit 85d4d69

Browse files
authored
Fix filter by instances on Insights (#92)
1 parent 17ab76a commit 85d4d69

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

captum/insights/api.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def _predictions_matches_labels(
170170
return labels == predicted_label
171171

172172
def _should_keep_prediction(
173-
self, predicted_scores: List[OutputScore], actual_label: str
173+
self, predicted_scores: List[OutputScore], actual_label: OutputScore
174174
) -> bool:
175175
# filter by class
176176
if len(self._config.classes) != 0:
@@ -180,13 +180,14 @@ def _should_keep_prediction(
180180
return False
181181

182182
# filter by accuracy
183+
label_name = actual_label.label
183184
if self._config.prediction == "all":
184185
pass
185186
elif self._config.prediction == "correct":
186-
if not self._predictions_matches_labels(predicted_scores, actual_label):
187+
if not self._predictions_matches_labels(predicted_scores, label_name):
187188
return False
188189
elif self._config.prediction == "incorrect":
189-
if self._predictions_matches_labels(predicted_scores, actual_label):
190+
if self._predictions_matches_labels(predicted_scores, label_name):
190191
return False
191192
else:
192193
raise Exception(f"Invalid prediction config: {self._config.prediction}")
@@ -238,16 +239,16 @@ def _calculate_vis_output(
238239
predicted = predicted.cpu().squeeze(0)
239240

240241
if label is not None and len(label) > 0:
241-
actual_label = OutputScore(
242-
score=0, index=label[0], label=self.classes[label[0]]
242+
actual_label_output = OutputScore(
243+
score=100, index=label[0], label=self.classes[label[0]]
243244
)
244245
else:
245-
actual_label = None
246+
actual_label_output = None
246247

247248
predicted_scores = self._get_labels_from_scores(scores, predicted)
248249

249250
# Filter based on UI configuration
250-
if not self._should_keep_prediction(predicted_scores, actual_label):
251+
if not self._should_keep_prediction(predicted_scores, actual_label_output):
251252
return None
252253

253254
baselines = [tuple(b) for b in baselines]
@@ -277,9 +278,9 @@ def _calculate_vis_output(
277278

278279
return VisualizationOutput(
279280
feature_outputs=features_per_input,
280-
actual=actual_label,
281+
actual=actual_label_output,
281282
predicted=predicted_scores,
282-
active_index=target if target is not None else actual_label.index,
283+
active_index=target if target is not None else actual_label_output.index,
283284
)
284285

285286
def _get_outputs(self) -> List[VisualizationOutput]:

0 commit comments

Comments
 (0)