Skip to content

Commit 66c9a9e

Browse files
committed
Merge branch 'master' of github.com:nimarb/pytorch_influence_functions
2 parents 4d8547e + fc88319 commit 66c9a9e

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

pytorch_influence_functions/calc_influence_function.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def calc_img_wise(config, model, train_loader, test_loader):
466466

467467
start_time = time.time()
468468
influence, harmful, helpful, _ = calc_influence_single(
469-
model, train_loader, test_loader, test_id_num=i, gpu=0,
469+
model, train_loader, test_loader, test_id_num=i, gpu=config['gpu'],
470470
recursion_depth=config['recursion_depth'], r=config['r_averaging'])
471471
end_time = time.time()
472472

pytorch_influence_functions/influence_function.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def s_test(z_test, t_test, model, z_loader, gpu=-1, damp=0.01, scale=25.0,
4343
x, t = x.cuda(), t.cuda()
4444
y = model(x)
4545
loss = calc_loss(y, t)
46-
hv = hvp(loss, list(model.parameters()), h_estimate)
46+
params = [ p for p in model.parameters() if p.requires_grad ]
47+
hv = hvp(loss, params, h_estimate)
4748
# Recursively caclulate h_estimate
4849
h_estimate = [
4950
_v + (1 - damp) * _h_e - _hv / scale
@@ -93,7 +94,8 @@ def grad_z(z, t, model, gpu=-1):
9394
y = model(z)
9495
loss = calc_loss(y, t)
9596
# Compute sum of gradients from model parameters to loss
96-
return list(grad(loss, list(model.parameters()), create_graph=True))
97+
params = [ p for p in model.parameters() if p.requires_grad ]
98+
return list(grad(loss, params, create_graph=True))
9799

98100

99101
def hvp(y, w, v):

0 commit comments

Comments
 (0)