@@ -397,15 +397,19 @@ def attribute(
397
397
if show_progress :
398
398
attr_progress .update ()
399
399
if agg_output_mode :
400
- eval_diff = modified_eval - prev_results
400
+ eval_diff = (modified_eval - prev_results ).to (
401
+ inputs_tuple [0 ].device
402
+ )
401
403
prev_results = modified_eval
402
404
else :
403
405
# when perturb_per_eval > 1, every num_examples stands for
404
406
# one perturb. Since the perturbs are from a consecutive
405
407
# perumuation, each diff of a perturb is its eval minus
406
408
# the eval of the previous perturb
407
409
all_eval = torch .cat ((prev_results , modified_eval ), dim = 0 )
408
- eval_diff = all_eval [num_examples :] - all_eval [:- num_examples ]
410
+ eval_diff = (
411
+ all_eval [num_examples :] - all_eval [:- num_examples ]
412
+ ).to (inputs_tuple [0 ].device )
409
413
prev_results = all_eval [- num_examples :]
410
414
411
415
for j in range (len (total_attrib )):
@@ -689,7 +693,7 @@ def _evalFutToPrevResultsTuple(
689
693
agg_output_mode ,
690
694
) = prev_results_tuple
691
695
if agg_output_mode :
692
- eval_diff = modified_eval - prev_results
696
+ eval_diff = ( modified_eval - prev_results ). to ( inputs_tuple [ 0 ]. device )
693
697
prev_results = modified_eval
694
698
else :
695
699
# when perturb_per_eval > 1, every num_examples stands for
@@ -698,7 +702,9 @@ def _evalFutToPrevResultsTuple(
698
702
# the eval of the previous perturb
699
703
700
704
all_eval = torch .cat ((prev_results , modified_eval ), dim = 0 )
701
- eval_diff = all_eval [num_examples :] - all_eval [:- num_examples ]
705
+ eval_diff = (all_eval [num_examples :] - all_eval [:- num_examples ]).to (
706
+ inputs_tuple [0 ].device
707
+ )
702
708
prev_results = all_eval [- num_examples :]
703
709
704
710
for j in range (len (total_attrib )):
@@ -799,7 +805,10 @@ def _perturbation_generator(
799
805
)
800
806
current_tensors_list .append (current_tensors )
801
807
current_mask_list .append (
802
- tuple (mask == feature_permutation [i ] for mask in input_masks )
808
+ tuple (
809
+ (mask == feature_permutation [i ]).to (inputs [0 ].device )
810
+ for mask in input_masks
811
+ )
803
812
)
804
813
if len (current_tensors_list ) == perturbations_per_eval :
805
814
combined_inputs = tuple (
0 commit comments