Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytorch_forecasting/layers/_output/_flatten_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ def forward(self, x):
x = self.flatten(x)
x = self.linear(x)
x = self.dropout(x)
x = x.permute(0, 2, 1)

if self.n_quantiles is not None:
batch_size, n_vars = x.shape[0], x.shape[1]
x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
batch_size = x.shape[0]
x = x.reshape(batch_size, -1, self.n_quantiles)
return x
11 changes: 3 additions & 8 deletions pytorch_forecasting/models/dlinear/_dlinear_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,22 +239,17 @@ def _reshape_output(self, output: torch.Tensor) -> torch.Tensor:
Returns
-------
output: torch.Tensor
Reshaped tensor (batch_size, prediction_length, n_features, n_quantiles)
Reshaped tensor (batch_size, prediction_length, n_quantiles)
or (batch_size, prediction_length, n_features) if n_quantiles is None.
"""
if self.n_quantiles is not None:
batch_size, n_features = output.shape[0], output.shape[1]
batch_size = output.shape[0]
output = output.reshape(
batch_size, n_features, self.prediction_length, self.n_quantiles
batch_size, self.prediction_length, self.n_quantiles
)
output = output.permute(0, 2, 1, 3) # (batch, time, features, quantiles)
else:
output = output.permute(0, 2, 1) # (batch, time, features)

# univariate forecasting
if self.target_dim == 1 and output.shape[-1] == 1:
output = output.squeeze(-1)

return output

def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
Expand Down
9 changes: 0 additions & 9 deletions pytorch_forecasting/models/timexer/_timexer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,6 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

dec_out = self.head(enc_out)

if self.n_quantiles is not None:
dec_out = dec_out.permute(0, 2, 1, 3)
else:
dec_out = dec_out.permute(0, 2, 1)

return dec_out

def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
Expand All @@ -330,10 +325,6 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
out = self._forecast(x)
prediction = out[:, : self.prediction_length, :]

# check to see if the output shape is equal to number of targets
if prediction.size(2) != self.target_dim:
prediction = prediction[:, :, : self.target_dim]

if "target_scale" in x:
prediction = self.transform_output(prediction, x["target_scale"])

Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_dlinear_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_quantile_loss_output(sample_dataset):

assert "prediction" in output
pred = output["prediction"]
assert pred.ndim == 4
assert pred.ndim == 3
assert pred.shape[-1] == len(quantiles)
assert pred.shape[1] == metadata["prediction_length"]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_timexer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_quantile_predictions(basic_metadata):
output = model(sample_input_data)

predictions = output["prediction"]
assert predictions.shape == (batch_size, 8, 1, 3)
assert predictions.shape == (batch_size, 8, 3)


def test_missing_history_target_handling(basic_metadata):
Expand Down
Loading