From b903cdee0807935fb2c86c121cc12a0b0797cd2a Mon Sep 17 00:00:00 2001 From: Teo Date: Fri, 4 Aug 2023 09:27:44 +0200 Subject: [PATCH] Bugfix: handle case where 'losses' is a list of tensors --- pytorch_forecasting/metrics/base_metrics.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_forecasting/metrics/base_metrics.py b/pytorch_forecasting/metrics/base_metrics.py index 041d24044..d757be67b 100644 --- a/pytorch_forecasting/metrics/base_metrics.py +++ b/pytorch_forecasting/metrics/base_metrics.py @@ -822,6 +822,9 @@ def mask_losses(self, losses: torch.Tensor, lengths: torch.Tensor, reduction: st """ if reduction is None: reduction = self.reduction + + if isinstance(losses, list): losses = losses[0] + if losses.ndim > 0: # mask loss mask = torch.arange(losses.size(1), device=losses.device).unsqueeze(0) >= lengths.unsqueeze(-1)