diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1351417aa..3e6b1288f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,27 @@
 # Release Notes
 
+## v1.1.0 Adding TSMixer
+
+### Added
+
+- New model TSMixer, an all-MLP architecture reported to outperform TFT, based on [TSMixer: An All-MLP Architecture for Time Series Forecasting](https://arxiv.org/abs/2303.06053).
+
+### Fixes
+
+- Multiple small fixes
+
+### Contributors
+
+- jdb78
+- jurgispods
+- jacktang
+- andre-marcos-perez
+- tmxt
+- bohdan-safoniuk
+- maartensukel
+- CahidArda
+- MBelniak
+
 ## v1.0.0 Update to pytorch 2.0 (10/04/2023)
 
 
diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py
index 51ee5b06c..631f4915f 100644
--- a/pytorch_forecasting/__init__.py
+++ b/pytorch_forecasting/__init__.py
@@ -43,6 +43,7 @@
     NHiTS,
     RecurrentNetwork,
     TemporalFusionTransformer,
+    TSMixer,
     get_rnn,
 )
 from pytorch_forecasting.utils import (
@@ -69,6 +70,7 @@
     "NBeats",
     "NHiTS",
     "Baseline",
+    "TSMixer",
     "DeepAR",
     "BaseModel",
     "BaseModelWithCovariates",
diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py
index adbe79338..363384b96 100644
--- a/pytorch_forecasting/models/__init__.py
+++ b/pytorch_forecasting/models/__init__.py
@@ -15,6 +15,7 @@
 from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn
 from pytorch_forecasting.models.rnn import RecurrentNetwork
 from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
+from pytorch_forecasting.models.tsmixer import TSMixer
 
 __all__ = [
     "NBeats",
@@ -32,4 +33,5 @@
     "GRU",
     "MultiEmbedding",
     "DecoderMLP",
+    "TSMixer",
 ]
diff --git a/pytorch_forecasting/models/tsmixer/__init__.py b/pytorch_forecasting/models/tsmixer/__init__.py
new file mode 100644
index 000000000..daea21b37
--- /dev/null
+++ b/pytorch_forecasting/models/tsmixer/__init__.py
@@ -0,0 +1,237 @@
+"""
+TSMixer is a fairly simple architecture shown to have beaten the likes of the Temporal Fusion Transformer.
+
+Reference: `TSMixer: An All-MLP Architecture for Time Series Forecasting <https://arxiv.org/abs/2303.06053>`_
+"""
+
+from copy import copy
+from typing import Dict, List, Optional, Tuple, Union
+
+from matplotlib import pyplot as plt
+import numpy as np
+import torch
+from torch import nn
+from torchmetrics import Metric as LightningMetric
+
+from pytorch_forecasting.data import TimeSeriesDataSet
+from pytorch_forecasting.data.encoders import NaNLabelEncoder
+from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss, QuantileLoss
+from pytorch_forecasting.models.base_model import BaseModelWithCovariates
+from pytorch_forecasting.models.nn import LSTM, MultiEmbedding
+from pytorch_forecasting.models.tsmixer.submodules import TSMixerEncoder
+from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list
+
+
+class TSMixer(BaseModelWithCovariates):
+    def __init__(
+        self,
+        hidden_size: int = 16,
+        lstm_layers: int = 1,
+        dropout: float = 0.1,
+        output_size: Union[int, List[int]] = 7,
+        loss: MultiHorizonMetric = None,
+        attention_head_size: int = 4,
+        max_encoder_length: int = 10,
+        static_categoricals: Optional[List[str]] = None,
+        static_reals: Optional[List[str]] = None,
+        time_varying_categoricals_encoder: Optional[List[str]] = None,
+        time_varying_categoricals_decoder: Optional[List[str]] = None,
+        categorical_groups: Optional[Dict[str, List[str]]] = None,
+        time_varying_reals_encoder: Optional[List[str]] = None,
+        time_varying_reals_decoder: Optional[List[str]] = None,
+        x_reals: Optional[List[str]] = None,
+        x_categoricals: Optional[List[str]] = None,
+        hidden_continuous_size: int = 8,
+        hidden_continuous_sizes: Optional[Dict[str, int]] = None,
+        embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None,
+        embedding_paddings: Optional[List[str]] = None,
+        embedding_labels: Optional[Dict[str, np.ndarray]] = None,
+        learning_rate: float = 1e-3,
+        log_interval: Union[int, float] = -1,
+        log_val_interval: Union[int, float] = None,
+        log_gradient_flow: bool = False,
+        reduce_on_plateau_patience: int = 1000,
+        monotone_constaints: Optional[Dict[str, int]] = None,
+        share_single_variable_networks: bool = False,
+        causal_attention: bool = True,
+        logging_metrics: nn.ModuleList = None,
+        **kwargs,
+    ):
+        """
+        TSMixer for forecasting timeseries - use its :py:meth:`~from_dataset` method if possible.
+
+        Implementation of the article
+        `TSMixer: An All-MLP Architecture for Time Series Forecasting <https://arxiv.org/abs/2303.06053>`_
+
+        NOTE(review): the signature mirrors the Temporal Fusion Transformer. Apart from ``loss`` and
+        ``logging_metrics``, this constructor itself only reads ``hidden_size``, ``embedding_sizes``,
+        ``categorical_groups``, ``embedding_paddings`` and ``x_categoricals`` (for the input embeddings);
+        whether the remaining arguments (e.g. ``lstm_layers``, ``attention_head_size``) are used elsewhere
+        should be confirmed. List and dictionary arguments left at ``None`` are replaced by empty containers.
+
+        Args:
+
+            hidden_size: hidden size of network which is its main hyperparameter and can range from 8 to 512
+            lstm_layers: number of LSTM layers (2 is mostly optimal)
+            dropout: dropout rate
+            output_size: number of outputs (e.g. number of quantiles for QuantileLoss and one target or list
+                of output sizes).
+            loss: loss function taking prediction and targets
+            attention_head_size: number of attention heads (4 is a good default)
+            max_encoder_length: length to encode (can be far longer than the decoder length but does not have to be)
+            static_categoricals: names of static categorical variables
+            static_reals: names of static continuous variables
+            time_varying_categoricals_encoder: names of categorical variables for encoder
+            time_varying_categoricals_decoder: names of categorical variables for decoder
+            time_varying_reals_encoder: names of continuous variables for encoder
+            time_varying_reals_decoder: names of continuous variables for decoder
+            categorical_groups: dictionary where values
+                are list of categorical variables that are forming together a new categorical
+                variable which is the key in the dictionary
+            x_reals: order of continuous variables in tensor passed to forward function
+            x_categoricals: order of categorical variables in tensor passed to forward function
+            hidden_continuous_size: default for hidden size for processing continuous variables (similar to categorical
+                embedding size)
+            hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection
+                (fallback to hidden_continuous_size if index is not in dictionary)
+            embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and
+                embedding size
+            embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector
+            embedding_labels: dictionary mapping (string) indices to list of categorical labels
+            learning_rate: learning rate
+            log_interval: log predictions every x batches, do not log if 0 or less, log interpretation if > 0. If < 1.0,
+                will log multiple entries per batch. Defaults to -1.
+            log_val_interval: frequency with which to log validation set metrics, defaults to log_interval
+            log_gradient_flow: whether to log gradient flow, this takes time and should be only done to diagnose training
+                failures
+            reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
+            monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder
+                variables mapping
+                position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive,
+                larger numbers add more weight to the constraint vs. the loss but are usually not necessary).
+                This constraint significantly slows down training. Defaults to None which is treated as {}.
+            share_single_variable_networks (bool): whether to share the single variable networks between the encoder and
+                decoder. Defaults to False.
+            causal_attention (bool): whether to attend only at previous timesteps in the decoder or also include future
+                predictions. Defaults to True.
+            logging_metrics (nn.ModuleList[LightningMetric]): list of metrics that are logged during training.
+                Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
+        """
+        # resolve ``None`` defaults to fresh containers before ``save_hyperparameters()`` so that no mutable
+        # default object is shared between instances (same pattern as for ``loss`` and ``logging_metrics`` below)
+        if static_categoricals is None:
+            static_categoricals = []
+        if static_reals is None:
+            static_reals = []
+        if time_varying_categoricals_encoder is None:
+            time_varying_categoricals_encoder = []
+        if time_varying_categoricals_decoder is None:
+            time_varying_categoricals_decoder = []
+        if categorical_groups is None:
+            categorical_groups = {}
+        if time_varying_reals_encoder is None:
+            time_varying_reals_encoder = []
+        if time_varying_reals_decoder is None:
+            time_varying_reals_decoder = []
+        if x_reals is None:
+            x_reals = []
+        if x_categoricals is None:
+            x_categoricals = []
+        if hidden_continuous_sizes is None:
+            hidden_continuous_sizes = {}
+        if embedding_sizes is None:
+            embedding_sizes = {}
+        if embedding_paddings is None:
+            embedding_paddings = []
+        if embedding_labels is None:
+            embedding_labels = {}
+        if monotone_constaints is None:
+            monotone_constaints = {}
+        if logging_metrics is None:
+            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()])
+        if loss is None:
+            loss = QuantileLoss()
+        self.save_hyperparameters()
+        # store loss function separately as it is a module
+        assert isinstance(loss, LightningMetric), "Loss has to be a PyTorch Lightning `Metric`"
+        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
+
+        # processing inputs
+        # embeddings
+        self.input_embeddings = MultiEmbedding(
+            embedding_sizes=self.hparams.embedding_sizes,
+            categorical_groups=self.hparams.categorical_groups,
+            embedding_paddings=self.hparams.embedding_paddings,
+            x_categoricals=self.hparams.x_categoricals,
+            max_embedding_size=self.hparams.hidden_size,
+        )
+
+    @classmethod
+    def from_dataset(
+        cls,
+        dataset: TimeSeriesDataSet,
+        allowed_encoder_known_variable_names: List[str] = None,
+        **kwargs,
+    ):
+        """
+        Create model from dataset.
+
+        Args:
+            dataset: timeseries dataset
+            allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all
+            **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``)
+
+        Returns:
+            TSMixer
+        """
+        # add maximum encoder length
+        # update defaults
+        new_kwargs = copy(kwargs)
+        new_kwargs["max_encoder_length"] = dataset.max_encoder_length
+        new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss()))
+
+        # create class and return
+        return super().from_dataset(
+            dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs
+        )
+
+    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Assemble the per-variable inputs of a batch - the mixer network itself is not wired in yet.
+
+        Args:
+            x: batch dictionary; reads ``encoder_cat``, ``decoder_cat``, ``encoder_cont``, ``decoder_cont``
+                (each presumably n_samples x time x variables), ``encoder_lengths`` and ``decoder_lengths``
+
+        Returns:
+            output of ``to_network_output`` built from ``decoder_lengths`` and ``encoder_lengths`` only -
+            no ``prediction`` is produced yet
+        """
+        encoder_lengths = x["encoder_lengths"]
+        decoder_lengths = x["decoder_lengths"]
+        x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1)  # concatenate in time dimension
+        x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1)  # concatenate in time dimension
+        # timesteps = x_cont.size(1)  # encode + decode length
+        # max_encoder_length = int(encoder_lengths.max())
+        input_vectors = self.input_embeddings(x_cat)
+        input_vectors.update(
+            {
+                name: x_cont[..., idx].unsqueeze(-1)
+                for idx, name in enumerate(self.hparams.x_reals)
+                if name in self.reals
+            }
+        )
+
+        # NOTE(review): ``input_vectors`` is not used further and no ``prediction`` is passed below - the
+        # commented-out entries are placeholders, so this forward pass is still a stub.
+        return self.to_network_output(
+            # prediction=self.transform_output(output, target_scale=x["target_scale"]),
+            # encoder_attention=attn_output_weights[..., :max_encoder_length],
+            # decoder_attention=attn_output_weights[..., max_encoder_length:],
+            # static_variables=static_variable_selection,
+            # encoder_variables=encoder_sparse_weights,
+            # decoder_variables=decoder_sparse_weights,
+            decoder_lengths=decoder_lengths,
+            encoder_lengths=encoder_lengths,
+        )
diff --git a/pytorch_forecasting/models/tsmixer/submodules.py b/pytorch_forecasting/models/tsmixer/submodules.py
new file mode 100644
index 000000000..14be30143
--- /dev/null
+++ b/pytorch_forecasting/models/tsmixer/submodules.py
@@ -0,0 +1,326 @@
+"""Building blocks for the TSMixer model: linear/MLP layers with residual connections and layer norm."""
+
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+def _get_activation(name: Optional[str]) -> Optional[Callable[[torch.Tensor], torch.Tensor]]:
+    """Return the function ``torch.nn.functional.<name>`` or ``None`` if no activation is requested."""
+    return None if name is None else getattr(F, name)
+
+
+class TemporalLinear(nn.Module):
+    """Linear layer over dimension 1 of a 3-dimensional input, followed by optional activation and dropout."""
+
+    def __init__(
+        self,
+        input_len: int,
+        output_len: int,
+        activation: Optional[str] = None,
+        dropout: Optional[float] = 0,
+    ):
+        """
+        Args:
+            input_len: size of dimension 1 of the input (presumably the number of time steps)
+            output_len: size of dimension 1 of the output
+            activation: name of a function in ``torch.nn.functional``; ``None`` applies no activation
+            dropout: dropout probability; ``None`` is treated like 0
+        """
+        super().__init__()
+        self.linear = nn.Linear(in_features=input_len, out_features=output_len)
+        self.activation = _get_activation(activation)
+        # ``nn.Dropout`` cannot handle ``None`` although the annotation allows it
+        self.dropout = nn.Dropout(p=0.0 if dropout is None else dropout)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # swap the last two dimensions so that ``nn.Linear`` acts on dimension 1, then swap back
+        x = self.linear(x.permute(0, 2, 1)).permute(0, 2, 1)
+        x = x if self.activation is None else self.activation(x)
+        x = self.dropout(x)
+        return x
+
+
+class TemporalResBlock(nn.Module):
+    """Residual block computing ``LayerNorm(x + TemporalLinear(x))``; the input shape is preserved."""
+
+    def __init__(
+        self,
+        input_len: int,
+        input_size: int,
+        activation: Optional[str] = None,
+        dropout: Optional[float] = 0,
+    ):
+        """
+        Args:
+            input_len: size of dimension 1 of the input
+            input_size: size of dimension 2 of the input
+            activation: activation name passed on to :py:class:`TemporalLinear`
+            dropout: dropout probability passed on to :py:class:`TemporalLinear`
+        """
+        super().__init__()
+        self.temporal_linear = TemporalLinear(input_len, input_len, activation, dropout)
+        # normalises jointly over the last two dimensions
+        self.norm = nn.LayerNorm(normalized_shape=(input_len, input_size))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        res = x
+        x = self.temporal_linear(x)
+        return self.norm(res + x)
+
+
+class FeaturalResBlock(nn.Module):
+    """Two-layer MLP over the last dimension with a residual connection and layer norm."""
+
+    def __init__(
+        self,
+        input_len: int,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        activation: Optional[str] = "relu",
+        dropout: Optional[float] = 0,
+    ):
+        """
+        Args:
+            input_len: size of dimension 1 of the input
+            input_size: size of the last dimension of the input
+            hidden_size: width of the hidden layer
+            output_size: size of the last dimension of the output
+            activation: name of a function in ``torch.nn.functional``; ``None`` applies no activation
+            dropout: dropout probability; ``None`` is treated like 0
+        """
+        super().__init__()
+        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
+        self.linear2 = nn.Linear(in_features=hidden_size, out_features=output_size)
+        # the residual path needs a projection only if the feature size changes
+        self.res_linear = None
+        if input_size != output_size:
+            self.res_linear = nn.Linear(in_features=input_size, out_features=output_size)
+        self.activation = _get_activation(activation)
+        self.dropout = nn.Dropout(p=0.0 if dropout is None else dropout)
+        self.norm = nn.LayerNorm(normalized_shape=(input_len, output_size))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        res = x if self.res_linear is None else self.res_linear(x)
+        x = self.linear1(x)
+        x = x if self.activation is None else self.activation(x)
+        x = self.dropout(x)
+        x = self.linear2(x)
+        x = self.dropout(x)
+        return self.norm(res + x)
+
+
+class ConditionalFeaturalResBlock(nn.Module):
+    """:py:class:`FeaturalResBlock` whose input is concatenated with an embedding of static features."""
+
+    def __init__(
+        self,
+        input_len: int,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        static_size: int,
+        activation: Optional[str] = "relu",
+        dropout: Optional[float] = 0,
+    ):
+        """
+        Args:
+            input_len: size of dimension 1 of the input
+            input_size: size of the last dimension of the input
+            hidden_size: width of the hidden layers; also the size of the static embedding
+            output_size: size of the last dimension of the output
+            static_size: number of static features
+            activation: activation name passed on to both :py:class:`FeaturalResBlock` instances
+            dropout: dropout probability passed on to both :py:class:`FeaturalResBlock` instances
+        """
+        super().__init__()
+        self.input_len = input_len
+        self.static_block = FeaturalResBlock(1, static_size, hidden_size, hidden_size, activation, dropout)
+        self.block = FeaturalResBlock(
+            input_len,
+            input_size + hidden_size,
+            hidden_size,
+            output_size,
+            activation,
+            dropout,
+        )
+
+    def forward(self, x: torch.Tensor, static: torch.Tensor) -> torch.Tensor:
+        # embed the static features as a length-1 sequence, then repeat them along dimension 1
+        static = self.static_block(static.unsqueeze(1))
+        static = torch.repeat_interleave(static, self.input_len, dim=1)
+        x = torch.cat([x, static], dim=2)
+        x = self.block(x)
+        return x
+
+
+class MixerBlock(nn.Module):
+    """:py:class:`TemporalResBlock` followed by a :py:class:`FeaturalResBlock`."""
+
+    def __init__(
+        self,
+        input_len: int,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        activation: str,
+        dropout: float,
+    ):
+        """
+        Args:
+            input_len: size of dimension 1 of the input
+            input_size: size of the last dimension of the input
+            hidden_size: width of the hidden layer of the feature block
+            output_size: size of the last dimension of the output
+            activation: activation name used by both sub-blocks
+            dropout: dropout probability used by both sub-blocks
+        """
+        super().__init__()
+        self.temporal_res_block = TemporalResBlock(input_len, input_size, activation, dropout)
+        self.ffwd_res_block = FeaturalResBlock(
+            input_len,
+            input_size,
+            hidden_size,
+            output_size,
+            activation,
+            dropout,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.temporal_res_block(x)
+        x = self.ffwd_res_block(x)
+        return x
+
+
+class ConditionalMixerBlock(nn.Module):
+    """:py:class:`TemporalResBlock` followed by a :py:class:`ConditionalFeaturalResBlock`."""
+
+    def __init__(
+        self,
+        input_len: int,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        static_size: int,
+        activation: str,
+        dropout: float,
+    ):
+        """
+        Args:
+            input_len: size of dimension 1 of the input
+            input_size: size of the last dimension of the input
+            hidden_size: width of the hidden layers of the feature block
+            output_size: size of the last dimension of the output
+            static_size: number of static features
+            activation: activation name used by both sub-blocks
+            dropout: dropout probability used by both sub-blocks
+        """
+        super().__init__()
+        self.temporal_res_block = TemporalResBlock(input_len, input_size, activation, dropout)
+        self.ffwd_res_block = ConditionalFeaturalResBlock(
+            input_len,
+            input_size,
+            hidden_size,
+            output_size,
+            static_size,
+            activation,
+            dropout,
+        )
+
+    def forward(self, x: torch.Tensor, static: torch.Tensor) -> torch.Tensor:
+        x = self.temporal_res_block(x)
+        x = self.ffwd_res_block(x, static)
+        return x
+
+
+class TSMixerEncoder(nn.Module):
+    """
+    Maps past features to the output length, embeds past and future features conditioned on the static
+    features, concatenates both and applies ``n_block`` :py:class:`ConditionalMixerBlock` layers.
+    """
+
+    def __init__(
+        self,
+        input_len: int,
+        output_len: int,
+        past_feat_size: int,
+        future_feat_size: int,
+        static_feat_size: int,
+        hidden_size: int,
+        activation: str,
+        dropout: float,
+        n_block: Optional[int] = 1,
+    ):
+        """
+        Args:
+            input_len: size of dimension 1 of ``past_feature``
+            output_len: size of dimension 1 of ``future_feature`` and of the output
+            past_feat_size: size of the last dimension of ``past_feature``
+            future_feat_size: size of the last dimension of ``future_feature``
+            static_feat_size: number of static features
+            hidden_size: hidden width of all blocks and size of the last dimension of the output
+            activation: activation name used by the feature and mixer blocks
+            dropout: dropout probability used by the feature and mixer blocks
+            n_block: number of stacked :py:class:`ConditionalMixerBlock` layers
+        """
+        super().__init__()
+        self.past_temporal_linear = TemporalLinear(input_len, output_len)
+        self.past_featural_block = ConditionalFeaturalResBlock(
+            input_len=output_len,
+            input_size=past_feat_size,
+            hidden_size=hidden_size,
+            output_size=hidden_size,
+            static_size=static_feat_size,
+            activation=activation,
+            dropout=dropout,
+        )
+        self.future_featural_block = ConditionalFeaturalResBlock(
+            input_len=output_len,
+            input_size=future_feat_size,
+            hidden_size=hidden_size,
+            output_size=hidden_size,
+            static_size=static_feat_size,
+            activation=activation,
+            dropout=dropout,
+        )
+        self.blocks = nn.ModuleList(
+            [
+                ConditionalMixerBlock(
+                    input_len=output_len,
+                    input_size=(2 * hidden_size) if i == 0 else hidden_size,
+                    hidden_size=hidden_size,
+                    output_size=hidden_size,
+                    static_size=static_feat_size,
+                    activation=activation,
+                    dropout=dropout,
+                )
+                for i in range(n_block)
+            ]
+        )
+
+    def forward(
+        self,
+        past_feature: torch.Tensor,
+        future_feature: torch.Tensor,
+        static_feature: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+            past_feature: tensor of shape (batch, input_len, past_feat_size)
+            future_feature: tensor of shape (batch, output_len, future_feat_size)
+            static_feature: tensor of shape (batch, static_feat_size)
+
+        Returns:
+            tensor of shape (batch, output_len, hidden_size)
+        """
+        past_feature = self.past_temporal_linear(past_feature)
+        past_feature = self.past_featural_block(past_feature, static_feature)
+        future_feature = self.future_featural_block(future_feature, static_feature)
+        # concatenate past and future representations along the feature dimension
+        x = torch.cat([past_feature, future_feature], dim=2)
+        for block in self.blocks:
+            x = block(x, static_feature)
+        return x
diff --git a/tests/test_models/test_tsmixer.py b/tests/test_models/test_tsmixer.py
new file mode 100644
index 000000000..645327a08
--- /dev/null
+++ b/tests/test_models/test_tsmixer.py
@@ -0,0 +1,174 @@
+import pickle
+import shutil
+import sys
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import EarlyStopping
+from lightning.pytorch.loggers import TensorBoardLogger
+import pytest
+import torch
+
+from pytorch_forecasting.data import NaNLabelEncoder
+from pytorch_forecasting.data.encoders import MultiNormalizer
+from pytorch_forecasting.metrics import CrossEntropy, MQF2DistributionLoss, MultiLoss, PoissonLoss, QuantileLoss
+from pytorch_forecasting.models import TSMixer
+
+if sys.version.startswith("3.6"):  # python 3.6 does not have nullcontext
+    from contextlib import contextmanager
+
+    @contextmanager
+    def nullcontext(enter_result=None):
+        yield enter_result
+
+else:
+    from contextlib import nullcontext
+
+
+def test_integration(multiple_dataloaders_with_covariates, tmp_path):
+    _integration(multiple_dataloaders_with_covariates, tmp_path, trainer_kwargs=dict(accelerator="cpu"))
+
+
+def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs):
+    train_dataloader = dataloader["train"]
+    val_dataloader = dataloader["val"]
+    test_dataloader = dataloader["test"]
+
+    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
+
+    # check training
+    logger = TensorBoardLogger(tmp_path)
+    if trainer_kwargs is None:
+        trainer_kwargs = {}
+    trainer = pl.Trainer(
+        max_epochs=2,
+        gradient_clip_val=0.1,
+        callbacks=[early_stop_callback],
+        enable_checkpointing=True,
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        logger=logger,
+        **trainer_kwargs
+    )
+    # test monotone constraints automatically
+    if "discount_in_percent" in train_dataloader.dataset.reals:
+        monotone_constaints = {"discount_in_percent": +1}
+        cuda_context = torch.backends.cudnn.flags(enabled=False)
+    else:
+        monotone_constaints = {}
+        cuda_context = nullcontext()
+
+    kwargs.setdefault("learning_rate", 0.15)
+
+    with cuda_context:
+        if loss is not None:
+            pass
+        elif isinstance(train_dataloader.dataset.target_normalizer, NaNLabelEncoder):
+            loss = CrossEntropy()
+        elif isinstance(train_dataloader.dataset.target_normalizer, MultiNormalizer):
+            loss = MultiLoss(
+                [
+                    CrossEntropy() if isinstance(normalizer, NaNLabelEncoder) else QuantileLoss()
+                    for normalizer in train_dataloader.dataset.target_normalizer.normalizers
+                ]
+            )
+        else:
+            loss = QuantileLoss()
+        net = TSMixer.from_dataset(
+            train_dataloader.dataset,
+            hidden_size=2,
+            hidden_continuous_size=2,
+            attention_head_size=1,
+            dropout=0.2,
+            loss=loss,
+            log_interval=5,
+            log_val_interval=1,
+            log_gradient_flow=True,
+            monotone_constaints=monotone_constaints,
+            **kwargs
+        )
+        net.size()
+        try:
+            trainer.fit(
+                net,
+                train_dataloaders=train_dataloader,
+                val_dataloaders=val_dataloader,
+            )
+            # todo: testing somehow disables grad computation even though it is explicitly turned on -
+            #       loss is calculated as "grad" for MQF2
+            if not isinstance(net.loss, MQF2DistributionLoss):
+                test_outputs = trainer.test(net, dataloaders=test_dataloader)
+                assert len(test_outputs) > 0
+
+            # check loading
+            net = TSMixer.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
+
+            # check prediction
+            predictions = net.predict(
+                val_dataloader,
+                return_index=True,
+                return_x=True,
+                return_y=True,
+                fast_dev_run=True,
+                trainer_kwargs=trainer_kwargs,
+            )
+            pred_len = len(predictions.index)
+
+            # check that output is of correct shape
+            def check(x):
+                if isinstance(x, (tuple, list)):
+                    for xi in x:
+                        check(xi)
+                elif isinstance(x, dict):
+                    for xi in x.values():
+                        check(xi)
+                else:
+                    assert pred_len == x.shape[0], "first dimension should be prediction length"
+
+            check(predictions.output)
+            if isinstance(predictions.output, torch.Tensor):
+                assert predictions.output.ndim == 2, "shape of predictions should be batch_size x timesteps"
+            else:
+                assert all(
+                    p.ndim == 2 for p in predictions.output
+                ), "shape of predictions should be batch_size x timesteps"
+            check(predictions.x)
+            check(predictions.index)
+
+            # predict raw
+            net.predict(
+                val_dataloader,
+                return_index=True,
+                return_x=True,
+                fast_dev_run=True,
+                mode="raw",
+                trainer_kwargs=trainer_kwargs,
+            )
+
+        finally:
+            shutil.rmtree(tmp_path, ignore_errors=True)
+
+
+@pytest.fixture
+def model(dataloaders_with_covariates):
+    dataset = dataloaders_with_covariates["train"].dataset
+    net = TSMixer.from_dataset(
+        dataset,
+        learning_rate=0.15,
+        hidden_size=4,
+        attention_head_size=1,
+        dropout=0.2,
+        hidden_continuous_size=2,
+        loss=PoissonLoss(),
+        output_size=1,
+        log_interval=5,
+        log_val_interval=1,
+        log_gradient_flow=True,
+    )
+    return net
+
+
+def test_pickle(model):
+    pkl = pickle.dumps(model)
+    pickle.loads(pkl)