Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
252598d
D1, D2 layer commit
phoeenniixx Apr 6, 2025
d0d1c3e
remove one comment
phoeenniixx Apr 6, 2025
80e64d2
model layer commit
phoeenniixx Apr 6, 2025
6364780
update docstring
phoeenniixx Apr 6, 2025
82b3dc7
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 6, 2025
257183c
update data_module.py
phoeenniixx Apr 10, 2025
9cdcb19
update data_module.py
phoeenniixx Apr 10, 2025
a83bf32
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
ac56d4f
Add disclaimer
phoeenniixx Apr 10, 2025
0e7e36f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
4bfff21
update docstring
phoeenniixx Apr 11, 2025
ef98273
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 11, 2025
8a53ed6
Add tests for D1,D2 layer
phoeenniixx Apr 19, 2025
9f9df31
Merge branch 'main' into refactor-d1-d2
phoeenniixx Apr 19, 2025
cdecb77
Code quality
phoeenniixx Apr 19, 2025
86360fd
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 19, 2025
20aafb7
refactor file
fkiraly Apr 30, 2025
043820d
warning
fkiraly Apr 30, 2025
1720a15
linting
fkiraly May 1, 2025
af44474
move coercion to utils
fkiraly May 1, 2025
a3cb8b7
linting
fkiraly May 1, 2025
75d7fb5
Update _timeseries_v2.py
fkiraly May 1, 2025
1b946e6
Update __init__.py
fkiraly May 1, 2025
3edb08b
Update __init__.py
fkiraly May 1, 2025
a4bc9d8
Merge branch 'main' into pr/1811
fkiraly May 1, 2025
4c0d570
Merge branch 'pr/1811' into pr/1812
fkiraly May 1, 2025
e350291
update tests
phoeenniixx May 11, 2025
f90c94f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 11, 2025
3099691
update tft_v2
phoeenniixx May 11, 2025
77cb979
warnings and init attr handling
fkiraly May 13, 2025
28df3c3
Merge branch 'refactor-d1-d2' of https://github.com/phoeenniixx/pytor…
fkiraly May 13, 2025
f8c94e6
simplify TimeSeries.__getitem__
fkiraly May 13, 2025
c289255
Update _timeseries_v2.py
fkiraly May 13, 2025
9467f38
Update data_module.py
fkiraly May 13, 2025
c3b40ad
backwards compat of private/public attrs
fkiraly May 13, 2025
c007310
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 13, 2025
2e25052
Merge branch 'main' into refactor-model
phoeenniixx May 13, 2025
38c28dc
add tests
phoeenniixx May 14, 2025
9d80eb8
add tests
phoeenniixx May 14, 2025
a8ccfe3
add tests
phoeenniixx May 14, 2025
f900ba5
add more docstrings
phoeenniixx May 14, 2025
ed1b799
add note about the commented out tests
phoeenniixx May 14, 2025
c947910
Merge branch 'main' into refactor-model
phoeenniixx May 16, 2025
c0ceb8a
add the commented out tests
phoeenniixx May 16, 2025
3828c26
remove note
phoeenniixx May 16, 2025
6d6d18e
Merge branch 'main' into refactor-model
phoeenniixx May 18, 2025
30b541b
make the modules private
phoeenniixx May 20, 2025
3f1e11f
Merge remote-tracking branch 'origin/refactor-model' into refactor-model
phoeenniixx May 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 296 additions & 0 deletions pytorch_forecasting/models/base/_base_model_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
########################################################################################
# Disclaimer: This baseclass is still work in progress and experimental, please
# use with care. This class is a basic skeleton of how the base classes may look like
# in the version-2.
########################################################################################


from typing import Dict, List, Optional, Tuple, Union
from warnings import warn

Check warning on line 9 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L8-L9

Added lines #L8 - L9 were not covered by tests

from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
import torch
import torch.nn as nn
from torch.optim import Optimizer

Check warning on line 15 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L11-L15

Added lines #L11 - L15 were not covered by tests


class BaseModel(LightningModule):
def __init__(

Check warning on line 19 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L18-L19

Added lines #L18 - L19 were not covered by tests
self,
loss: nn.Module,
logging_metrics: Optional[List[nn.Module]] = None,
optimizer: Optional[Union[Optimizer, str]] = "adam",
optimizer_params: Optional[Dict] = None,
lr_scheduler: Optional[str] = None,
lr_scheduler_params: Optional[Dict] = None,
):
"""
Base model for time series forecasting.

Parameters
----------
loss : nn.Module
Loss function to use for training.
logging_metrics : Optional[List[nn.Module]], optional
List of metrics to log during training, validation, and testing.
optimizer : Optional[Union[Optimizer, str]], optional
Optimizer to use for training.
Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`.
optimizer_params : Optional[Dict], optional
Parameters for the optimizer.
lr_scheduler : Optional[str], optional
Learning rate scheduler to use.
Supported values: "reduce_lr_on_plateau", "step_lr".
lr_scheduler_params : Optional[Dict], optional
Parameters for the learning rate scheduler.
"""
super().__init__()
self.loss = loss
self.logging_metrics = logging_metrics if logging_metrics is not None else []
self.optimizer = optimizer
self.optimizer_params = optimizer_params if optimizer_params is not None else {}
self.lr_scheduler = lr_scheduler
self.lr_scheduler_params = (

Check warning on line 54 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L48-L54

Added lines #L48 - L54 were not covered by tests
lr_scheduler_params if lr_scheduler_params is not None else {}
)
self.model_name = self.__class__.__name__
warn(

Check warning on line 58 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L57-L58

Added lines #L57 - L58 were not covered by tests
f"The Model '{self.model_name}' is part of an experimental rework"
"of the pytorch-forecasting model layer, scheduled for release with v2.0.0."
" The API is not stable and may change without prior warning. "
"This class is intended for beta testing and as a basic skeleton, "
"but not for stable production use. "
"Feedback and suggestions are very welcome in "
"pytorch-forecasting issue 1736, "
"https://github.com/sktime/pytorch-forecasting/issues/1736",
UserWarning,
)

def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

Check warning on line 70 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L70

Added line #L70 was not covered by tests
"""
Forward pass of the model.

Parameters
----------
x : Dict[str, torch.Tensor]
Dictionary containing input tensors

Returns
-------
Dict[str, torch.Tensor]
Dictionary containing output tensors
"""
raise NotImplementedError("Forward method must be implemented by subclass.")

Check warning on line 84 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L84

Added line #L84 was not covered by tests

def training_step(

Check warning on line 86 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L86

Added line #L86 was not covered by tests
self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
"""
Training step for the model.

Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input and target tensors.
batch_idx : int
Index of the batch.

Returns
-------
STEP_OUTPUT
Dictionary containing the loss and other metrics.
"""
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
self.log(

Check warning on line 108 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L104-L108

Added lines #L104 - L108 were not covered by tests
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
)
self.log_metrics(y_hat, y, prefix="train")
return {"loss": loss}

Check warning on line 112 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L111-L112

Added lines #L111 - L112 were not covered by tests

def validation_step(

Check warning on line 114 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L114

Added line #L114 was not covered by tests
self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
"""
Validation step for the model.

Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input and target tensors.
batch_idx : int
Index of the batch.

Returns
-------
STEP_OUTPUT
Dictionary containing the loss and other metrics.
"""
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
self.log(

Check warning on line 136 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L132-L136

Added lines #L132 - L136 were not covered by tests
"val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
self.log_metrics(y_hat, y, prefix="val")
return {"val_loss": loss}

Check warning on line 140 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L139-L140

Added lines #L139 - L140 were not covered by tests

def test_step(

Check warning on line 142 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L142

Added line #L142 was not covered by tests
self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
"""
Test step for the model.

Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input and target tensors.
batch_idx : int
Index of the batch.

Returns
-------
STEP_OUTPUT
Dictionary containing the loss and other metrics.
"""
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
self.log(

Check warning on line 164 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L160-L164

Added lines #L160 - L164 were not covered by tests
"test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
self.log_metrics(y_hat, y, prefix="test")
return {"test_loss": loss}

Check warning on line 168 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L167-L168

Added lines #L167 - L168 were not covered by tests

def predict_step(

Check warning on line 170 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L170

Added line #L170 was not covered by tests
self,
batch: Tuple[Dict[str, torch.Tensor]],
batch_idx: int,
dataloader_idx: int = 0,
) -> torch.Tensor:
"""
Prediction step for the model.

Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input tensors.
batch_idx : int
Index of the batch.
dataloader_idx : int
Index of the dataloader.

Returns
-------
torch.Tensor
Predicted output tensor.
"""
x, _ = batch
y_hat = self(x)
return y_hat

Check warning on line 195 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L193-L195

Added lines #L193 - L195 were not covered by tests

def configure_optimizers(self) -> Dict:

Check warning on line 197 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L197

Added line #L197 was not covered by tests
"""
Configure the optimizer and learning rate scheduler.

Returns
-------
Dict
Dictionary containing the optimizer and scheduler configuration.
"""
optimizer = self._get_optimizer()
if self.lr_scheduler is not None:
scheduler = self._get_scheduler(optimizer)
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
return {

Check warning on line 210 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L206-L210

Added lines #L206 - L210 were not covered by tests
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss",
},
}
else:
return {"optimizer": optimizer, "lr_scheduler": scheduler}
return {"optimizer": optimizer}

Check warning on line 219 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L218-L219

Added lines #L218 - L219 were not covered by tests

def _get_optimizer(self) -> Optimizer:

Check warning on line 221 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L221

Added line #L221 was not covered by tests
"""
Get the optimizer based on the specified optimizer name and parameters.

Returns
-------
Optimizer
The optimizer instance.
"""
if isinstance(self.optimizer, str):
if self.optimizer.lower() == "adam":
return torch.optim.Adam(self.parameters(), **self.optimizer_params)
elif self.optimizer.lower() == "sgd":
return torch.optim.SGD(self.parameters(), **self.optimizer_params)

Check warning on line 234 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L230-L234

Added lines #L230 - L234 were not covered by tests
else:
raise ValueError(f"Optimizer {self.optimizer} not supported.")
elif isinstance(self.optimizer, Optimizer):
return self.optimizer

Check warning on line 238 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L236-L238

Added lines #L236 - L238 were not covered by tests
else:
raise ValueError(

Check warning on line 240 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L240

Added line #L240 was not covered by tests
"Optimizer must be either a string or "
"an instance of torch.optim.Optimizer."
)

def _get_scheduler(

Check warning on line 245 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L245

Added line #L245 was not covered by tests
self, optimizer: Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
"""
Get the lr scheduler based on the specified scheduler name and params.

Parameters
----------
optimizer : Optimizer
The optimizer instance.

Returns
-------
torch.optim.lr_scheduler._LRScheduler
The learning rate scheduler instance.
"""
if self.lr_scheduler.lower() == "reduce_lr_on_plateau":
return torch.optim.lr_scheduler.ReduceLROnPlateau(

Check warning on line 262 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L261-L262

Added lines #L261 - L262 were not covered by tests
optimizer, **self.lr_scheduler_params
)
elif self.lr_scheduler.lower() == "step_lr":
return torch.optim.lr_scheduler.StepLR(

Check warning on line 266 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L265-L266

Added lines #L265 - L266 were not covered by tests
optimizer, **self.lr_scheduler_params
)
else:
raise ValueError(f"Scheduler {self.lr_scheduler} not supported.")

Check warning on line 270 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L270

Added line #L270 was not covered by tests

def log_metrics(

Check warning on line 272 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L272

Added line #L272 was not covered by tests
self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val"
) -> None:
"""
Log additional metrics during training, validation, or testing.

Parameters
----------
y_hat : torch.Tensor
Predicted output tensor.
y : torch.Tensor
Target output tensor.
prefix : str
Prefix for the logged metrics (e.g., "train", "val", "test").
"""
for metric in self.logging_metrics:
metric_value = metric(y_hat, y)
self.log(

Check warning on line 289 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L287-L289

Added lines #L287 - L289 were not covered by tests
f"{prefix}_{metric.__class__.__name__}",
metric_value,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
Loading
Loading