RMSE problem
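- RMSE is normally computed as the square root of the MSE loss. Note that torch.nn.MSELoss is a module, so it has to be instantiated before it is called:
rmse = torch.sqrt(torch.nn.MSELoss()(logits, y.float()))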
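- But when the MSE is exactly 0 (for example, when the predictions match the targets perfectly), the backward pass produces NaN gradients. The forward value torch.sqrt(0) is simply 0. The problem is the derivative of sqrt, 1/(2*sqrt(x)), which is infinite at 0. Backprop multiplies that infinity by the MSE gradient, which is also 0 at this point, and 0 * inf = NaN.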
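A minimal sketch that reproduces the problem. The toy tensors below are hypothetical values chosen so that logits match y exactly, which makes the MSE exactly 0. The forward value comes out as 0, while the gradient comes out as NaN:

```python
import torch

# Hypothetical toy data: predictions equal the targets, so the MSE is exactly 0.
logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([1, 2, 3])

rmse = torch.sqrt(torch.nn.MSELoss()(logits, y.float()))
print(rmse.item())  # forward pass: 0.0, not NaN

rmse.backward()
print(logits.grad)  # backward pass: tensor([nan, nan, nan])
```

If this happens in a real training loop, a single NaN gradient is enough to turn every weight into NaN after the next optimizer step.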
Solution
- So, add a very small number, called epsilon, inside the square root so its argument can never be exactly 0.
Here's the solution:
eps = 1e-6  # keeps the sqrt argument strictly positive
rmse = torch.sqrt(torch.nn.MSELoss()(logits, y.float()) + eps)
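A minimal sketch checking the fix on the same hypothetical toy tensors (predictions equal to targets). With eps inside the square root, the gradient is finite, and in this case it is exactly 0:

```python
import torch

eps = 1e-6  # small constant added inside the square root

# Hypothetical toy data: predictions equal the targets, so the MSE is exactly 0.
logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([1, 2, 3])

rmse = torch.sqrt(torch.nn.MSELoss()(logits, y.float()) + eps)
rmse.backward()

print(rmse.item())  # about 1e-3, i.e. sqrt(eps)
print(logits.grad)  # tensor([0., 0., 0.]): finite, no NaN
```

One trade-off: eps puts a floor of sqrt(eps) (about 1e-3 for eps = 1e-6) under the reported RMSE. That is harmless as a training loss, but for reporting the metric you may prefer to compute the RMSE without eps under torch.no_grad(), where no gradient is needed.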