diff --git a/pytorch_influence_functions/calc_influence_function.py b/pytorch_influence_functions/calc_influence_function.py index 8861cd4..8499a82 100644 --- a/pytorch_influence_functions/calc_influence_function.py +++ b/pytorch_influence_functions/calc_influence_function.py @@ -269,7 +269,7 @@ def calc_influence_function(train_dataset_size, grad_z_vecs=None, # There is one grad_z per training data sample ################################### ]) / train_dataset_size - influences.append(tmp_influence) + influences.append(tmp_influence.cpu()) display_progress("Calc. influence function: ", i, train_dataset_size) harmful = np.argsort(influences) @@ -340,7 +340,7 @@ def calc_influence_single(model, train_loader, test_loader, test_id_num, gpu, torch.sum(k * j).data for k, j in zip(grad_z_vec, s_test_vec) ]) / train_dataset_size - influences.append(tmp_influence) + influences.append(tmp_influence.cpu()) display_progress("Calc. influence function: ", i, train_dataset_size) harmful = np.argsort(influences)