Skip to content

Commit 5248929

Browse files
sarahtranfbfacebook-github-bot
authored andcommitted
Ensure eval diff and mask tensor are on the same device (#1542)
Summary: Pull Request resolved: #1542 Failure (prod pkg): f718895429 Reviewed By: vivekmig Differential Revision: D72682579 fbshipit-source-id: 78045d6c58af10fae118204fcc4433d65d986eff
1 parent 20082a1 commit 5248929

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

captum/attr/_core/shapley_value.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,19 @@ def attribute(
397397
if show_progress:
398398
attr_progress.update()
399399
if agg_output_mode:
400-
eval_diff = modified_eval - prev_results
400+
eval_diff = (modified_eval - prev_results).to(
401+
inputs_tuple[0].device
402+
)
401403
prev_results = modified_eval
402404
else:
403405
# when perturb_per_eval > 1, every num_examples stands for
404406
# one perturb. Since the perturbs are from a consecutive
405407
# perumuation, each diff of a perturb is its eval minus
406408
# the eval of the previous perturb
407409
all_eval = torch.cat((prev_results, modified_eval), dim=0)
408-
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
410+
eval_diff = (
411+
all_eval[num_examples:] - all_eval[:-num_examples]
412+
).to(inputs_tuple[0].device)
409413
prev_results = all_eval[-num_examples:]
410414

411415
for j in range(len(total_attrib)):
@@ -689,7 +693,7 @@ def _evalFutToPrevResultsTuple(
689693
agg_output_mode,
690694
) = prev_results_tuple
691695
if agg_output_mode:
692-
eval_diff = modified_eval - prev_results
696+
eval_diff = (modified_eval - prev_results).to(inputs_tuple[0].device)
693697
prev_results = modified_eval
694698
else:
695699
# when perturb_per_eval > 1, every num_examples stands for
@@ -698,7 +702,9 @@ def _evalFutToPrevResultsTuple(
698702
# the eval of the previous perturb
699703

700704
all_eval = torch.cat((prev_results, modified_eval), dim=0)
701-
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
705+
eval_diff = (all_eval[num_examples:] - all_eval[:-num_examples]).to(
706+
inputs_tuple[0].device
707+
)
702708
prev_results = all_eval[-num_examples:]
703709

704710
for j in range(len(total_attrib)):
@@ -799,7 +805,10 @@ def _perturbation_generator(
799805
)
800806
current_tensors_list.append(current_tensors)
801807
current_mask_list.append(
802-
tuple(mask == feature_permutation[i] for mask in input_masks)
808+
tuple(
809+
(mask == feature_permutation[i]).to(inputs[0].device)
810+
for mask in input_masks
811+
)
803812
)
804813
if len(current_tensors_list) == perturbations_per_eval:
805814
combined_inputs = tuple(

0 commit comments

Comments
 (0)