From 8c7d277a5c8c91b31e04f7cccda95377b1136dff Mon Sep 17 00:00:00 2001 From: Zayd Hammoudeh Date: Sat, 23 Jan 2021 06:35:27 -0800 Subject: [PATCH] Fix #23: Improve calc_loss's numerical stability using cross entropy. --- pytorch_influence_functions/influence_function.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_influence_functions/influence_function.py b/pytorch_influence_functions/influence_function.py index 7b08e37..f7230e4 100644 --- a/pytorch_influence_functions/influence_function.py +++ b/pytorch_influence_functions/influence_function.py @@ -67,9 +67,10 @@ def calc_loss(y, t): # if dim == [0, 1, 3] then dim=0; else dim=1 #################### # y = torch.nn.functional.log_softmax(y, dim=0) - y = torch.nn.functional.log_softmax(y) - loss = torch.nn.functional.nll_loss( - y, t, weight=None, reduction='mean') + # y = torch.nn.functional.log_softmax(y) + # loss = torch.nn.functional.nll_loss( + # y, t, weight=None, reduction='mean') + loss = torch.nn.functional.cross_entropy(y, t, weight=None, reduction="mean") return loss