Skip to content

Commit 106f82f

Browse files
authored
[ENH] Add missing test for forward output of TimeXer as proposed in #1936 (#1951)
#### What does this implement/fix? Explain your changes. Adds the test for the output shapes of the forward method of `TimeXer`. The test checks the output contract for the model, which is being implemented in #1936. Currently skipped under `test_timexer.py` due to obvious causes of failure i.e current version of model on main is still not updated to the required output contract.
1 parent 3821c0b commit 106f82f

File tree

1 file changed

+122
-12
lines changed

1 file changed

+122
-12
lines changed

tests/test_models/test_timexer.py

Lines changed: 122 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,27 @@
1717
from pytorch_forecasting.models import TimeXer
1818

1919

20+
def _expected_fwd_shape(batch_size, prediction_length, loss):
21+
"""
22+
Return the expected output shape for the forward pass of the model.
23+
"""
24+
25+
if isinstance(loss, QuantileLoss):
26+
n_quantiles = len(loss.quantiles)
27+
return (batch_size, prediction_length, n_quantiles)
28+
elif isinstance(loss, MultiLoss):
29+
shapes = []
30+
for single_loss in loss.losses:
31+
if isinstance(single_loss, QuantileLoss):
32+
n_quantiles = len(single_loss.quantiles)
33+
shapes.append((batch_size, prediction_length, n_quantiles))
34+
else:
35+
shapes.append((batch_size, prediction_length, 1))
36+
return shapes
37+
else:
38+
return (batch_size, prediction_length, 1)
39+
40+
2041
def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs):
2142
"""
2243
Integration test for the TimeXer model.
@@ -96,7 +117,6 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
96117
net = TimeXer.load_from_checkpoint(
97118
trainer.checkpoint_callback.best_model_path,
98119
)
99-
100120
predictions = net.predict(
101121
val_dataloader,
102122
return_index=True,
@@ -117,17 +137,6 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
117137
f"but got {predictions.output.shape}"
118138
)
119139

120-
# raw prediction if debugging the model
121-
122-
net.predict(
123-
val_dataloader,
124-
return_index=True,
125-
return_x=True,
126-
fast_dev_run=True,
127-
mode="raw",
128-
trainer_kwargs=trainer_kwargs,
129-
)
130-
131140
finally:
132141
# remove the temporary directory created for the test
133142
shutil.rmtree(tmp_path, ignore_errors=True)
@@ -471,3 +480,104 @@ def test_with_exogenous_variables(tmp_path):
471480

472481
finally:
473482
shutil.rmtree(tmp_path, ignore_errors=True)
483+
484+
485+
@pytest.mark.skipif(
486+
True, reason="Skipping due to incompatibility with current model outputs."
487+
) # noqa: E501
488+
def test_model_forward_output(dataloaders_with_covariates):
489+
"""
490+
Test the model's forward output shapes.
491+
This test checks that the model's forward pass returns outputs
492+
of expected shapes based on the loss function used.
493+
Args:
494+
dataloaders_with_covariates: The dataloaders to use for training and validation
495+
"""
496+
497+
train_dataloader = dataloaders_with_covariates["train"]
498+
val_dataloader = dataloaders_with_covariates["val"]
499+
500+
dataset = train_dataloader.dataset
501+
batch = next(iter(val_dataloader))
502+
x, y = batch
503+
504+
batch_size = x["encoder_cont"].shape[0]
505+
prediction_length = dataset.max_prediction_length
506+
507+
loss = MAE()
508+
model = TimeXer.from_dataset(
509+
dataset,
510+
hidden_size=16,
511+
n_heads=2,
512+
e_layers=1,
513+
patch_length=2,
514+
dropout=0.1,
515+
loss=loss,
516+
)
517+
518+
with torch.no_grad():
519+
output = model(x)
520+
521+
prediction = output["prediction"]
522+
expected_shape = _expected_fwd_shape(
523+
batch_size=batch_size,
524+
prediction_length=prediction_length,
525+
loss=loss,
526+
)
527+
528+
assert (
529+
prediction.shape == expected_shape
530+
), f"Expected output shape {expected_shape}, but got {prediction.shape}"
531+
532+
quantile_loss = QuantileLoss(quantiles=[0.1, 0.5, 0.9])
533+
model_quantile = TimeXer.from_dataset(
534+
dataset,
535+
hidden_size=16,
536+
n_heads=2,
537+
e_layers=1,
538+
patch_length=2,
539+
dropout=0.1,
540+
loss=quantile_loss,
541+
)
542+
543+
with torch.no_grad():
544+
output_quantile = model_quantile(x)
545+
prediction_quantile = output_quantile["prediction"]
546+
expected_shape_quantile = _expected_fwd_shape(
547+
batch_size=batch_size,
548+
prediction_length=prediction_length,
549+
loss=quantile_loss,
550+
)
551+
assert prediction_quantile.shape == expected_shape_quantile, (
552+
f"Expected output shape {expected_shape_quantile}, but got {prediction_quantile.shape}" # noqa: E501
553+
)
554+
555+
multi_loss = MultiLoss([MAE(), MAE()])
556+
model_multi = TimeXer.from_dataset(
557+
dataset,
558+
hidden_size=16,
559+
n_heads=2,
560+
e_layers=1,
561+
d_ff=32,
562+
patch_length=2,
563+
dropout=0.1,
564+
loss=multi_loss,
565+
)
566+
567+
with torch.no_grad():
568+
output_multi = model_multi(x)
569+
570+
prediction_multi = output_multi["prediction"]
571+
expected_shapes_multi = _expected_fwd_shape(
572+
batch_size, prediction_length, multi_loss
573+
)
574+
575+
assert isinstance(prediction_multi, list)
576+
assert len(prediction_multi) == len(expected_shapes_multi)
577+
578+
for i, (pred_tensor, expected_shape) in enumerate(
579+
zip(prediction_multi, expected_shapes_multi)
580+
): # noqa: E501
581+
assert (
582+
pred_tensor.shape == expected_shape
583+
), f"MultiLoss target {i}: Expected {expected_shape}, got {pred_tensor.shape}" # noqa: E501

0 commit comments

Comments
 (0)