|
| 1 | +from functools import wraps |
1 | 2 | import itertools
|
| 3 | +from unittest.mock import MagicMock, PropertyMock, patch |
2 | 4 |
|
3 | 5 | import pytest
|
4 | 6 | import torch
|
5 | 7 | from torch.nn.utils import rnn
|
6 | 8 |
|
| 9 | +from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet |
| 10 | +from pytorch_forecasting.data import NaNLabelEncoder |
7 | 11 | from pytorch_forecasting.data.encoders import TorchNormalizer
|
8 | 12 | from pytorch_forecasting.metrics import (
|
9 | 13 | MAE,
|
|
19 | 23 | AggregationMetric,
|
20 | 24 | CompositeMetric,
|
21 | 25 | )
|
| 26 | +from pytorch_forecasting.utils._dependencies import _get_installed_packages |
22 | 27 |
|
23 | 28 |
|
24 | 29 | def test_composite_metric():
|
@@ -306,6 +311,213 @@ def test_ImplicitQuantileNetworkDistributionLoss():
|
306 | 311 | assert point_prediction.ndim == loss.to_prediction(pred, n_samples=100).ndim
|
307 | 312 |
|
308 | 313 |
|
| 314 | +@pytest.fixture |
| 315 | +def sample_dataset(): |
| 316 | + """Fixture to create a sample TimeSeriesDataSet for testing.""" |
| 317 | + import numpy as np |
| 318 | + import pandas as pd |
| 319 | + |
| 320 | + rows = 15 |
| 321 | + df = pd.DataFrame( |
| 322 | + { |
| 323 | + "time": pd.date_range("2025-01-01", periods=rows, freq="h"), |
| 324 | + "label": ["test"] * rows, |
| 325 | + "var1": np.random.randn(rows).cumsum(), |
| 326 | + "var2": np.random.randn(rows).cumsum(), |
| 327 | + } |
| 328 | + ) |
| 329 | + df = df.sort_values("time").reset_index(drop=True) |
| 330 | + df["past_var1"] = df["var1"].shift(-1) |
| 331 | + df.dropna(subset=["past_var1"], inplace=True) |
| 332 | + df["time_idx"] = range(len(df)) |
| 333 | + return TimeSeriesDataSet( |
| 334 | + df, |
| 335 | + time_idx="time_idx", |
| 336 | + target="past_var1", |
| 337 | + group_ids=["label"], |
| 338 | + static_categoricals=["label"], |
| 339 | + time_varying_known_reals=["var1", "var2"], |
| 340 | + time_varying_unknown_reals=["past_var1"], |
| 341 | + max_encoder_length=5, |
| 342 | + max_prediction_length=2, |
| 343 | + categorical_encoders={"label": NaNLabelEncoder(add_nan=False)}, |
| 344 | + ) |
| 345 | + |
| 346 | + |
| 347 | +@pytest.fixture(params=["cuda", "cpu"]) |
| 348 | +def mock_device(request): |
| 349 | + """Fixture to create a mock device for testing.""" |
| 350 | + # Create a torch.device object |
| 351 | + device_str = f"{request.param}:0" if request.param == "cuda" else "cpu" |
| 352 | + mock_device = torch.device(device_str) |
| 353 | + |
| 354 | + orig_tensor = torch.tensor |
| 355 | + orig_empty = torch.empty |
| 356 | + |
| 357 | + @wraps(orig_tensor) |
| 358 | + def mock_tensor(data, *args, **kwargs): |
| 359 | + # Force device to CPU |
| 360 | + kwargs["device"] = "cpu" |
| 361 | + tensor = orig_tensor(data, *args, **kwargs) |
| 362 | + tensor.device = mock_device |
| 363 | + return tensor |
| 364 | + |
| 365 | + @wraps(orig_empty) |
| 366 | + def mock_empty(*args, **kwargs): |
| 367 | + kwargs["device"] = "cpu" |
| 368 | + tensor = orig_empty(*args, **kwargs) |
| 369 | + tensor.device = mock_device |
| 370 | + return tensor |
| 371 | + |
| 372 | + if request.param == "cuda": |
| 373 | + mock_properties = type( |
| 374 | + "CudaDeviceProperties", |
| 375 | + (), |
| 376 | + { |
| 377 | + "major": 8, |
| 378 | + "minor": 0, |
| 379 | + "name": "Mocked CUDA Device", |
| 380 | + "total_memory": 8 * 1024 * 1024 * 1024, |
| 381 | + }, |
| 382 | + )() |
| 383 | + |
| 384 | + with ( |
| 385 | + patch("torch.cuda.is_available", return_value=True), |
| 386 | + patch("torch.cuda._lazy_init", return_value=None), |
| 387 | + patch("torch.cuda.device_count", return_value=1), |
| 388 | + patch("torch.cuda.get_device_properties", return_value=mock_properties), |
| 389 | + patch("torch.cuda.get_device_capability", return_value=(8, 0)), |
| 390 | + patch("torch.cuda.set_device", return_value=None), |
| 391 | + patch("torch.empty", new=mock_empty), |
| 392 | + patch("torch.tensor", new=mock_tensor), |
| 393 | + patch( |
| 394 | + "torch.Tensor.to", |
| 395 | + new=lambda self, device, *args, **kwargs: self.clone() |
| 396 | + if isinstance(device, (str, torch.device)) |
| 397 | + and str(device).startswith("cuda") |
| 398 | + else self, |
| 399 | + ), |
| 400 | + patch( |
| 401 | + "torch.Tensor.device", |
| 402 | + new_callable=PropertyMock, |
| 403 | + return_value=mock_device, |
| 404 | + ), |
| 405 | + patch("torch.Tensor.cuda", new=lambda self, *args, **kwargs: self.clone()), |
| 406 | + patch("torch.nn.Module.cuda", new=lambda self, *args, **kwargs: self), |
| 407 | + patch("torch.nn.Module.to", new=lambda self, device, *args, **kwargs: self), |
| 408 | + ): |
| 409 | + yield "cuda" |
| 410 | + else: |
| 411 | + yield "cpu" |
| 412 | + |
| 413 | + |
| 414 | +@pytest.mark.skipif( |
| 415 | + "cpflows" not in _get_installed_packages(), |
| 416 | + reason="cpflows is not installed, skipping MQF2DistributionLoss tests", |
| 417 | +) |
| 418 | +def test_MQF2DistributionLoss_device_handling(mock_device): |
| 419 | + from pytorch_forecasting.metrics import MQF2DistributionLoss |
| 420 | + |
| 421 | + loss = MQF2DistributionLoss(prediction_length=2) |
| 422 | + |
| 423 | + assert next(loss.picnn.parameters()).device.type == mock_device |
| 424 | + |
| 425 | + if mock_device == "cuda": |
| 426 | + loss.cuda() |
| 427 | + assert next(loss.picnn.parameters()).device.type == "cuda" |
| 428 | + elif mock_device == "cpu": |
| 429 | + loss.cpu() |
| 430 | + assert next(loss.picnn.parameters()).device.type == "cpu" |
| 431 | + loss.to(mock_device) |
| 432 | + assert next(loss.picnn.parameters()).device.type == mock_device |
| 433 | + |
| 434 | + |
| 435 | +device_params = [ |
| 436 | + pytest.param( |
| 437 | + "cuda", |
| 438 | + marks=pytest.mark.skipif( |
| 439 | + not torch.cuda.is_available(), reason="CUDA is not available" |
| 440 | + ), |
| 441 | + ), |
| 442 | + "cpu", |
| 443 | +] |
| 444 | + |
| 445 | + |
| 446 | +@pytest.mark.skipif( |
| 447 | + "cpflows" not in _get_installed_packages(), |
| 448 | + reason="cpflows is not installed, skipping MQF2DistributionLoss tests", |
| 449 | +) |
| 450 | +@pytest.mark.parametrize("device", device_params) |
| 451 | +def test_MQF2DistributionLoss_full_workflow(sample_dataset, device): |
| 452 | + """ |
| 453 | + Test the complete workflow from training to prediction with MQF2DistributionLoss. |
| 454 | + """ |
| 455 | + import lightning.pytorch as pl |
| 456 | + |
| 457 | + from pytorch_forecasting.metrics import MQF2DistributionLoss |
| 458 | + |
| 459 | + model = TemporalFusionTransformer.from_dataset( |
| 460 | + sample_dataset, loss=MQF2DistributionLoss(prediction_length=2) |
| 461 | + ) |
| 462 | + |
| 463 | + trainer = pl.Trainer( |
| 464 | + max_epochs=1, |
| 465 | + accelerator=device, |
| 466 | + devices="auto", |
| 467 | + gradient_clip_val=0.1, |
| 468 | + limit_train_batches=30, |
| 469 | + limit_val_batches=3, |
| 470 | + ) |
| 471 | + dataloader = sample_dataset.to_dataloader(train=True, batch_size=4, num_workers=0) |
| 472 | + |
| 473 | + trainer.fit(model, dataloader) |
| 474 | + |
| 475 | + raw_predictions = model.predict( |
| 476 | + dataloader, |
| 477 | + mode="raw", |
| 478 | + return_x=True, |
| 479 | + trainer_kwargs=dict(accelerator=device, devices="auto", logger=False), |
| 480 | + ) |
| 481 | + # Verify predictions are on correct device |
| 482 | + pred_device = raw_predictions.output["prediction"].device.type |
| 483 | + target_device = raw_predictions.x["encoder_target"].device.type |
| 484 | + assert pred_device == device |
| 485 | + assert target_device == device |
| 486 | + try: |
| 487 | + model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0) |
| 488 | + plot_success = True |
| 489 | + except RuntimeError as e: |
| 490 | + if "device" in str(e).lower() or "expected" in str(e).lower(): |
| 491 | + plot_success = False |
| 492 | + pytest.fail(f"Device mismatch error during plotting: {e}") |
| 493 | + else: |
| 494 | + raise e |
| 495 | + assert plot_success, "Plotting failed due to device mismatch" |
| 496 | + |
| 497 | + |
| 498 | +@pytest.mark.skipif( |
| 499 | + "cpflows" not in _get_installed_packages(), |
| 500 | + reason="cpflows is not installed, skipping MQF2DistributionLoss tests", |
| 501 | +) |
| 502 | +def test_MQF2DistributionLoss_device_synchronization(mock_device, sample_dataset): |
| 503 | + """Test that MQF2DistributionLoss components are synchronized with the device.""" |
| 504 | + from pytorch_forecasting.metrics import MQF2DistributionLoss |
| 505 | + |
| 506 | + model = TemporalFusionTransformer.from_dataset( |
| 507 | + sample_dataset, loss=MQF2DistributionLoss(prediction_length=2) |
| 508 | + ) |
| 509 | + fake_prediction = torch.randn(4, 2, 8) |
| 510 | + |
| 511 | + if mock_device == "cuda": |
| 512 | + fake_prediction = fake_prediction.cuda() |
| 513 | + model.loss.map_x_to_distribution(fake_prediction) |
| 514 | + assert next(model.loss.picnn.parameters()).device.type == "cuda" |
| 515 | + if mock_device == "cpu": |
| 516 | + fake_prediction = fake_prediction.cpu() |
| 517 | + model.loss.map_x_to_distribution(fake_prediction) |
| 518 | + assert next(model.loss.picnn.parameters()).device.type == "cpu" |
| 519 | + |
| 520 | + |
309 | 521 | def test_CrossEntropyLoss():
|
310 | 522 | batch_size = 3
|
311 | 523 | n_timesteps = 5
|
|
0 commit comments