diff --git a/pytorch_influence_functions/influence_function.py b/pytorch_influence_functions/influence_function.py index 7b08e37..f7230e4 100644 --- a/pytorch_influence_functions/influence_function.py +++ b/pytorch_influence_functions/influence_function.py @@ -67,9 +67,10 @@ def calc_loss(y, t): # if dim == [0, 1, 3] then dim=0; else dim=1 #################### # y = torch.nn.functional.log_softmax(y, dim=0) - y = torch.nn.functional.log_softmax(y) - loss = torch.nn.functional.nll_loss( - y, t, weight=None, reduction='mean') + # y = torch.nn.functional.log_softmax(y) + # loss = torch.nn.functional.nll_loss( + # y, t, weight=None, reduction='mean') + loss = torch.nn.functional.cross_entropy(y, t, weight=None, reduction="mean") return loss