From 60ca7f6c2534aee1f2f2b80be10dd087ae092aac Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 6 Sep 2025 01:58:31 +0530 Subject: [PATCH 1/5] standardise dlinear --- pytorch_forecasting/models/dlinear/_dlinear_v2.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/pytorch_forecasting/models/dlinear/_dlinear_v2.py b/pytorch_forecasting/models/dlinear/_dlinear_v2.py index f4207a66e..6f940728c 100644 --- a/pytorch_forecasting/models/dlinear/_dlinear_v2.py +++ b/pytorch_forecasting/models/dlinear/_dlinear_v2.py @@ -243,18 +243,13 @@ def _reshape_output(self, output: torch.Tensor) -> torch.Tensor: or (batch_size, prediction_length, n_features) if n_quantiles is None. """ if self.n_quantiles is not None: - batch_size, n_features = output.shape[0], output.shape[1] + batch_size = output.shape[0] output = output.reshape( - batch_size, n_features, self.prediction_length, self.n_quantiles + batch_size, self.prediction_length, self.n_quantiles ) - output = output.permute(0, 2, 1, 3) # (batch, time, features, quantiles) else: output = output.permute(0, 2, 1) # (batch, time, features) - # univariate forecasting - if self.target_dim == 1 and output.shape[-1] == 1: - output = output.squeeze(-1) - return output def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: From 2ed40608dc8a99839f85748c9238769308352c5d Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 6 Sep 2025 01:59:16 +0530 Subject: [PATCH 2/5] update docstring --- pytorch_forecasting/models/dlinear/_dlinear_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/dlinear/_dlinear_v2.py b/pytorch_forecasting/models/dlinear/_dlinear_v2.py index 6f940728c..14271c11a 100644 --- a/pytorch_forecasting/models/dlinear/_dlinear_v2.py +++ b/pytorch_forecasting/models/dlinear/_dlinear_v2.py @@ -239,7 +239,7 @@ def _reshape_output(self, output: torch.Tensor) -> torch.Tensor: Returns ------- output: torch.Tensor - Reshaped tensor (batch_size, prediction_length, n_features, n_quantiles) + Reshaped tensor (batch_size, prediction_length, n_quantiles) or (batch_size, prediction_length, n_features) if n_quantiles is None. """ if self.n_quantiles is not None: From 29825bff66baaa8c0a64f58d0183f00132809661 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 6 Sep 2025 02:19:38 +0530 Subject: [PATCH 3/5] update tests --- tests/test_models/test_dlinear_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models/test_dlinear_v2.py b/tests/test_models/test_dlinear_v2.py index efe2a0552..f46c88d5f 100644 --- a/tests/test_models/test_dlinear_v2.py +++ b/tests/test_models/test_dlinear_v2.py @@ -124,7 +124,7 @@ def test_quantile_loss_output(sample_dataset): assert "prediction" in output pred = output["prediction"] - assert pred.ndim == 4 + assert pred.ndim == 3 assert pred.shape[-1] == len(quantiles) assert pred.shape[1] == metadata["prediction_length"] From bb08f9e65b55af580c3898563d02cc5f85b4db26 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 6 Sep 2025 19:35:44 +0530 Subject: [PATCH 4/5] update timexer --- pytorch_forecasting/layers/_output/_flatten_head.py | 5 +++-- pytorch_forecasting/models/timexer/_timexer_v2.py | 9 --------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/pytorch_forecasting/layers/_output/_flatten_head.py b/pytorch_forecasting/layers/_output/_flatten_head.py index 71823b162..ba2aa0eae 100644 --- a/pytorch_forecasting/layers/_output/_flatten_head.py +++ b/pytorch_forecasting/layers/_output/_flatten_head.py @@ -38,8 +38,9 @@ def forward(self, x): x = self.flatten(x) x = self.linear(x) x = self.dropout(x) + x = x.permute(0, 2, 1) if self.n_quantiles is not None: - batch_size, n_vars = x.shape[0], x.shape[1] - x = x.reshape(batch_size, n_vars, -1, self.n_quantiles) + batch_size = x.shape[0] + x = x.reshape(batch_size, -1, self.n_quantiles) return x diff --git a/pytorch_forecasting/models/timexer/_timexer_v2.py b/pytorch_forecasting/models/timexer/_timexer_v2.py index 4a401c497..7e31ef978 100644 --- a/pytorch_forecasting/models/timexer/_timexer_v2.py +++ b/pytorch_forecasting/models/timexer/_timexer_v2.py @@ -311,11 +311,6 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: dec_out = self.head(enc_out) - if self.n_quantiles is not None: - dec_out = dec_out.permute(0, 2, 1, 3) - else: - dec_out = dec_out.permute(0, 2, 1) - return dec_out def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: @@ -330,10 +325,6 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: out = self._forecast(x) prediction = out[:, : self.prediction_length, :] - # check to see if the output shape is equal to number of targets - if prediction.size(2) != self.target_dim: - prediction = prediction[:, :, : self.target_dim] - if "target_scale" in x: prediction = self.transform_output(prediction, x["target_scale"]) From f7aa7c694502072a71424a62135c0de5ad8243fa Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 6 Sep 2025 19:44:56 +0530 Subject: [PATCH 5/5] update timxer tests --- tests/test_models/test_timexer_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models/test_timexer_v2.py b/tests/test_models/test_timexer_v2.py index 3285bd0e7..33cec25b0 100644 --- a/tests/test_models/test_timexer_v2.py +++ b/tests/test_models/test_timexer_v2.py @@ -297,7 +297,7 @@ def test_quantile_predictions(basic_metadata): output = model(sample_input_data) predictions = output["prediction"] - assert predictions.shape == (batch_size, 8, 1, 3) + assert predictions.shape == (batch_size, 8, 3) def test_missing_history_target_handling(basic_metadata):