Skip to content

Commit f87bb56

Browse files
authored
[BUG] Device inconsistency in MQF2DistributionLoss raising: RuntimeError: Expected all tensors to be on the same device (#1916)
Fixes #1182. In the current implementation, the `picnn` is initialized during class construction, and the device it defaults to isn't updated when the model is moved to another device. * Added a device movement method `to()` to ensure that `picnn` is moved along with the loss function. * Added automatic device sync in `map_x_to_distribution` to ensure that `picnn` is on the same device as the input tensor. Also added tests that mock the accelerators at a high level to verify device synchronization within this class.
1 parent dfdf93e commit f87bb56

File tree

2 files changed

+221
-0
lines changed

2 files changed

+221
-0
lines changed

pytorch_forecasting/metrics/distributions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ def __init__(
394394
self.prediction_length = prediction_length
395395
self.es_num_samples = es_num_samples
396396
self.beta = beta
397+
self._transformation = None
397398

398399
# define picnn
399400
convexnet = PICNN(
@@ -421,11 +422,19 @@ def __init__(
421422

422423
self.picnn = SequentialNet(networks)
423424

425+
def to(self, *args, **kwargs):
    """Move the loss and its ``picnn`` network together.

    All positional and keyword arguments are forwarded unchanged, first to
    ``self.picnn.to`` and then to the parent class's ``to``. Calling
    ``loss.to(device)`` or ``loss.to(device=device)`` therefore behaves as
    before, while calls such as ``loss.to(device, dtype)`` or
    ``loss.to(dtype=...)`` no longer fail on this override's signature.

    NOTE(review): this assumes both ``self.picnn.to`` and the parent ``to``
    accept the ``torch.nn.Module.to`` argument forms — confirm against the
    base class, which is not visible here.

    Returns:
        Whatever the parent class's ``to`` returns.
    """
    # Move the network explicitly so it follows the loss even if the parent
    # implementation does not relocate it.
    self.picnn = self.picnn.to(*args, **kwargs)
    return super().to(*args, **kwargs)
429+
424430
@property
def is_energy_score(self) -> bool:
    """Whether a number of energy-score samples has been configured.

    Returns:
        ``True`` when ``self.es_num_samples`` is anything other than
        ``None``, ``False`` otherwise.
    """
    configured_samples = self.es_num_samples
    if configured_samples is None:
        return False
    return True
427433

428434
def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Distribution:
435+
if hasattr(self.picnn, "to"):
436+
self.picnn = self.picnn.to(x.device)
437+
429438
distr = self.distribution_class(
430439
picnn=self.picnn,
431440
hidden_state=x[..., :-2],

tests/test_metrics.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from functools import wraps
12
import itertools
3+
from unittest.mock import MagicMock, PropertyMock, patch
24

35
import pytest
46
import torch
57
from torch.nn.utils import rnn
68

9+
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
10+
from pytorch_forecasting.data import NaNLabelEncoder
711
from pytorch_forecasting.data.encoders import TorchNormalizer
812
from pytorch_forecasting.metrics import (
913
MAE,
@@ -19,6 +23,7 @@
1923
AggregationMetric,
2024
CompositeMetric,
2125
)
26+
from pytorch_forecasting.utils._dependencies import _get_installed_packages
2227

2328

2429
def test_composite_metric():
@@ -306,6 +311,213 @@ def test_ImplicitQuantileNetworkDistributionLoss():
306311
assert point_prediction.ndim == loss.to_prediction(pred, n_samples=100).ndim
307312

308313

314+
@pytest.fixture
def sample_dataset():
    """Fixture to create a small, reproducible TimeSeriesDataSet for testing.

    Builds a single-group hourly frame with two cumulative-sum random series
    (``var1``, ``var2``) and a target ``past_var1`` that is ``var1`` shifted
    by one step. The trailing row, whose shifted target is NaN, is dropped
    before ``time_idx`` is assigned.

    Returns:
        A ``TimeSeriesDataSet`` with ``max_encoder_length=5`` and
        ``max_prediction_length=2``.
    """
    import numpy as np
    import pandas as pd

    rows = 15
    # A seeded generator keeps the fixture data identical between runs, so a
    # failure cannot be caused (or hidden) by a particular random draw.
    rng = np.random.default_rng(42)
    df = pd.DataFrame(
        {
            "time": pd.date_range("2025-01-01", periods=rows, freq="h"),
            "label": ["test"] * rows,
            "var1": rng.standard_normal(rows).cumsum(),
            "var2": rng.standard_normal(rows).cumsum(),
        }
    )
    df = df.sort_values("time").reset_index(drop=True)
    df["past_var1"] = df["var1"].shift(-1)
    # shift(-1) leaves NaN in the last row; remove it before indexing time.
    df = df.dropna(subset=["past_var1"])
    df["time_idx"] = range(len(df))
    return TimeSeriesDataSet(
        df,
        time_idx="time_idx",
        target="past_var1",
        group_ids=["label"],
        static_categoricals=["label"],
        time_varying_known_reals=["var1", "var2"],
        time_varying_unknown_reals=["past_var1"],
        max_encoder_length=5,
        max_prediction_length=2,
        categorical_encoders={"label": NaNLabelEncoder(add_nan=False)},
    )
345+
346+
347+
@pytest.fixture(params=["cuda", "cpu"])
def mock_device(request):
    """Fixture to create a mock device for testing.

    Parametrized over ``"cuda"`` and ``"cpu"``.

    For ``"cuda"`` the fixture patches a set of ``torch.cuda`` functions,
    ``torch.tensor`` / ``torch.empty``, and several ``torch.Tensor`` and
    ``torch.nn.Module`` methods for the duration of the test, then yields
    the string ``"cuda"``. For ``"cpu"`` nothing is patched and the string
    ``"cpu"`` is yielded.

    Note that the yielded value is a plain string, not the ``torch.device``
    object built inside the fixture.
    """
    # Create a torch.device object
    device_str = f"{request.param}:0" if request.param == "cuda" else "cpu"
    mock_device = torch.device(device_str)

    # Keep references to the real factories; the wrappers below delegate to
    # them and they are what gets replaced inside the ``with`` block.
    orig_tensor = torch.tensor
    orig_empty = torch.empty

    @wraps(orig_tensor)
    def mock_tensor(data, *args, **kwargs):
        # Force device to CPU
        kwargs["device"] = "cpu"
        tensor = orig_tensor(data, *args, **kwargs)
        # NOTE(review): assigning to ``.device`` presumably only succeeds
        # while ``torch.Tensor.device`` is patched with the PropertyMock
        # below — confirm.
        tensor.device = mock_device
        return tensor

    @wraps(orig_empty)
    def mock_empty(*args, **kwargs):
        # Same approach as mock_tensor: allocate on CPU, then set the mock
        # device on the result.
        kwargs["device"] = "cpu"
        tensor = orig_empty(*args, **kwargs)
        tensor.device = mock_device
        return tensor

    if request.param == "cuda":
        # Minimal stand-in object returned by the patched
        # ``torch.cuda.get_device_properties``.
        mock_properties = type(
            "CudaDeviceProperties",
            (),
            {
                "major": 8,
                "minor": 0,
                "name": "Mocked CUDA Device",
                "total_memory": 8 * 1024 * 1024 * 1024,
            },
        )()

        with (
            # Make torch.cuda report a single available device.
            patch("torch.cuda.is_available", return_value=True),
            patch("torch.cuda._lazy_init", return_value=None),
            patch("torch.cuda.device_count", return_value=1),
            patch("torch.cuda.get_device_properties", return_value=mock_properties),
            patch("torch.cuda.get_device_capability", return_value=(8, 0)),
            patch("torch.cuda.set_device", return_value=None),
            # Tensor factories are replaced by the CPU-forcing wrappers above.
            patch("torch.empty", new=mock_empty),
            patch("torch.tensor", new=mock_tensor),
            # ``Tensor.to`` returns a clone for a str/torch.device target
            # whose string form starts with "cuda", otherwise the tensor
            # itself.
            patch(
                "torch.Tensor.to",
                new=lambda self, device, *args, **kwargs: self.clone()
                if isinstance(device, (str, torch.device))
                and str(device).startswith("cuda")
                else self,
            ),
            # Every tensor's ``.device`` reads back as the mock device.
            patch(
                "torch.Tensor.device",
                new_callable=PropertyMock,
                return_value=mock_device,
            ),
            patch("torch.Tensor.cuda", new=lambda self, *args, **kwargs: self.clone()),
            # Module moves become no-ops that return the module unchanged.
            patch("torch.nn.Module.cuda", new=lambda self, *args, **kwargs: self),
            patch("torch.nn.Module.to", new=lambda self, device, *args, **kwargs: self),
        ):
            yield "cuda"
    else:
        yield "cpu"
412+
413+
414+
@pytest.mark.skipif(
    "cpflows" not in _get_installed_packages(),
    reason="cpflows is not installed, skipping MQF2DistributionLoss tests",
)
def test_MQF2DistributionLoss_device_handling(mock_device):
    """Check the reported device of ``picnn`` after each kind of move.

    The device type of the first ``picnn`` parameter is checked right after
    construction, after the ``.cuda()`` / ``.cpu()`` shortcut matching the
    fixture value, and after an explicit ``.to(mock_device)``.
    """
    from pytorch_forecasting.metrics import MQF2DistributionLoss

    loss = MQF2DistributionLoss(prediction_length=2)

    def picnn_device_type():
        # Device type reported by the first parameter of the picnn network.
        first_param = next(loss.picnn.parameters())
        return first_param.device.type

    assert picnn_device_type() == mock_device

    # Call the shortcut method named after the device: loss.cuda() / loss.cpu().
    if mock_device in ("cuda", "cpu"):
        getattr(loss, mock_device)()
        assert picnn_device_type() == mock_device

    loss.to(mock_device)
    assert picnn_device_type() == mock_device
433+
434+
435+
# Accelerators exercised by the full-workflow test; the "cuda" entry is
# skipped when torch reports that CUDA is not available.
_cuda_only = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is not available"
)
device_params = [pytest.param("cuda", marks=_cuda_only), "cpu"]
444+
445+
446+
@pytest.mark.skipif(
    "cpflows" not in _get_installed_packages(),
    reason="cpflows is not installed, skipping MQF2DistributionLoss tests",
)
@pytest.mark.parametrize("device", device_params)
def test_MQF2DistributionLoss_full_workflow(sample_dataset, device):
    """
    Test the complete workflow from training to prediction with MQF2DistributionLoss.

    Fits a TemporalFusionTransformer for one epoch on ``device``, runs a raw
    prediction, checks that the returned tensors report ``device`` as their
    device type, and finally plots a prediction. A ``RuntimeError`` raised
    while plotting whose message mentions "device" or "expected" is reported
    as a device-mismatch failure; any other ``RuntimeError`` propagates
    unchanged.
    """
    import lightning.pytorch as pl

    from pytorch_forecasting.metrics import MQF2DistributionLoss

    model = TemporalFusionTransformer.from_dataset(
        sample_dataset, loss=MQF2DistributionLoss(prediction_length=2)
    )

    trainer = pl.Trainer(
        max_epochs=1,
        accelerator=device,
        devices="auto",
        gradient_clip_val=0.1,
        limit_train_batches=30,
        limit_val_batches=3,
    )
    dataloader = sample_dataset.to_dataloader(train=True, batch_size=4, num_workers=0)

    trainer.fit(model, dataloader)

    raw_predictions = model.predict(
        dataloader,
        mode="raw",
        return_x=True,
        trainer_kwargs=dict(accelerator=device, devices="auto", logger=False),
    )
    # Verify predictions are on correct device
    pred_device = raw_predictions.output["prediction"].device.type
    target_device = raw_predictions.x["encoder_target"].device.type
    assert pred_device == device
    assert target_device == device

    # pytest.fail raises, so no success flag is needed: reaching the end of
    # the test means plotting succeeded.
    try:
        model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0)
    except RuntimeError as e:
        message = str(e).lower()
        if "device" in message or "expected" in message:
            pytest.fail(f"Device mismatch error during plotting: {e}")
        # Unrelated runtime errors keep their original traceback.
        raise
496+
497+
498+
@pytest.mark.skipif(
    "cpflows" not in _get_installed_packages(),
    reason="cpflows is not installed, skipping MQF2DistributionLoss tests",
)
def test_MQF2DistributionLoss_device_synchronization(mock_device, sample_dataset):
    """Check that ``map_x_to_distribution`` leaves ``picnn`` on the input's device.

    A random prediction tensor is moved with ``.cuda()`` or ``.cpu()``
    according to the fixture value and passed to the loss; afterwards the
    first ``picnn`` parameter must report that device type.
    """
    from pytorch_forecasting.metrics import MQF2DistributionLoss

    loss = MQF2DistributionLoss(prediction_length=2)
    model = TemporalFusionTransformer.from_dataset(sample_dataset, loss=loss)
    fake_prediction = torch.randn(4, 2, 8)

    if mock_device in ("cuda", "cpu"):
        # Pick Tensor.cuda() or Tensor.cpu() by name.
        move_tensor = getattr(fake_prediction, mock_device)
        fake_prediction = move_tensor()
        model.loss.map_x_to_distribution(fake_prediction)
        first_param = next(model.loss.picnn.parameters())
        assert first_param.device.type == mock_device
519+
520+
309521
def test_CrossEntropyLoss():
310522
batch_size = 3
311523
n_timesteps = 5

0 commit comments

Comments
 (0)