From 252598d2ce3f31244a422cd9206961776ea79615 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 18:43:51 +0530 Subject: [PATCH 01/33] D1, D2 layer commit --- pytorch_forecasting/data/data_module.py | 633 ++++++++++++++++++++++++ pytorch_forecasting/data/timeseries.py | 257 ++++++++++ 2 files changed, 890 insertions(+) create mode 100644 pytorch_forecasting/data/data_module.py diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py new file mode 100644 index 000000000..56917696d --- /dev/null +++ b/pytorch_forecasting/data/data_module.py @@ -0,0 +1,633 @@ +####################################################################################### +# Disclaimer: This data-module is still work in progress and experimental, please +# use with care. This data-module is a basic skeleton of how the data-handling pipeline +# may look like in the future. +# This is D2 layer that will handle the preprocessing and data loaders. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + +from typing import Any, Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_forecasting.data.encoders import ( + EncoderNormalizer, + NaNLabelEncoder, + TorchNormalizer, +) +from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict + +NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] + + +class EncoderDecoderTimeSeriesDataModule(LightningDataModule): + """ + Lightning DataModule for processing time series data in an encoder-decoder format. + + This module handles preprocessing, splitting, and batching of time series data + for use in deep learning models. It supports categorical and continuous features, + various scalers, and automatic target normalization. + + Parameters + ---------- + time_series_dataset : TimeSeries + The dataset containing time series data. + max_encoder_length : int, default=30 + Maximum length of the encoder input sequence. + min_encoder_length : Optional[int], default=None + Minimum length of the encoder input sequence. + Defaults to `max_encoder_length` if not specified. + max_prediction_length : int, default=1 + Maximum length of the decoder output sequence. + min_prediction_length : Optional[int], default=None + Minimum length of the decoder output sequence. + Defaults to `max_prediction_length` if not specified. + min_prediction_idx : Optional[int], default=None + Minimum index from which predictions start. + allow_missing_timesteps : bool, default=False + Whether to allow missing timesteps in the dataset. + add_relative_time_idx : bool, default=False + Whether to add a relative time index feature. + add_target_scales : bool, default=False + Whether to add target scaling information. + add_encoder_length : Union[bool, str], default="auto" + Whether to include encoder length information. + target_normalizer : + Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None], + default="auto" + Normalizer for the target variable. If "auto", uses `RobustScaler`. + + categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None + Dictionary of categorical encoders. 
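+        These are stored on the data module but not yet applied by this
+        skeleton (see the TODO in ``_preprocess_data``).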
+ + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. + + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. + """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + super().__init__() + self.time_series_dataset = time_series_dataset + self.time_series_metadata = time_series_dataset.get_metadata() + + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length or max_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_idx = min_prediction_idx + + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self.target_normalizer = RobustScaler() + else: + self.target_normalizer = target_normalizer + + self.categorical_encoders = _coerce_to_dict(categorical_encoders) + self.scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + self._metadata = None + + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + def _prepare_metadata(self): + """Prepare metadata for model initialisation. + + Returns + ------- + dict + dictionary containing the following keys: + + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. 
Computed by filtering ``self.time_series_metadata["cols"]["x"]``
+            where col_type == "C" (categorical) and col_known == "K" (known).
+        * ``decoder_cont``: Number of continuous variables in the decoder that
+            are known in advance.
+            Computed by filtering ``self.time_series_metadata["cols"]["x"]``
+            where col_type == "F" (continuous) and col_known == "K" (known).
+        * ``target``: Number of target variables.
+            Computed as ``len(self.time_series_metadata["cols"]["y"])``, which
+            gives the number of output target columns.
+        * ``static_categorical_features``: Number of static categorical features
+            Computed by filtering ``self.time_series_metadata["cols"]["st"]``
+            (static features) where col_type == "C" (categorical).
+        * ``static_continuous_features``: Number of static continuous features
+            Computed as the difference between
+            ``len(self.time_series_metadata["cols"]["st"])`` (static features)
+            and the number of static categorical features.
+        * ``max_encoder_length``: maximum encoder length
+            Taken directly from `self.max_encoder_length`.
+        * ``max_prediction_length``: maximum prediction length
+            Taken directly from `self.max_prediction_length`.
+        * ``min_encoder_length``: minimum encoder length
+            Taken directly from `self.min_encoder_length`.
+        * ``min_prediction_length``: minimum prediction length
+            Taken directly from `self.min_prediction_length`.
+
+        """
+        encoder_cat_count = len(self.categorical_indices)
+        encoder_cont_count = len(self.continuous_indices)
+
+        decoder_cat_count = len(
+            [
+                col
+                for col in self.time_series_metadata["cols"]["x"]
+                if self.time_series_metadata["col_type"].get(col) == "C"
+                and self.time_series_metadata["col_known"].get(col) == "K"
+            ]
+        )
+        decoder_cont_count = len(
+            [
+                col
+                for col in self.time_series_metadata["cols"]["x"]
+                if self.time_series_metadata["col_type"].get(col) == "F"
+                and self.time_series_metadata["col_known"].get(col) == "K"
+            ]
+        )
+
+        target_count = len(self.time_series_metadata["cols"]["y"])
+        metadata = {
+            "encoder_cat": encoder_cat_count,
+            "encoder_cont": encoder_cont_count,
+            "decoder_cat": decoder_cat_count,
+            "decoder_cont": decoder_cont_count,
+            "target": target_count,
+        }
+        if self.time_series_metadata["cols"]["st"]:
+            static_cat_count = len(
+                [
+                    col
+                    for col in self.time_series_metadata["cols"]["st"]
+                    if self.time_series_metadata["col_type"].get(col) == "C"
+                ]
+            )
+            static_cont_count = (
+                len(self.time_series_metadata["cols"]["st"]) - static_cat_count
+            )
+
+            metadata["static_categorical_features"] = static_cat_count
+            metadata["static_continuous_features"] = static_cont_count
+        else:
+            metadata["static_categorical_features"] = 0
+            metadata["static_continuous_features"] = 0
+
+        metadata.update(
+            {
+                "max_encoder_length": self.max_encoder_length,
+                "max_prediction_length": self.max_prediction_length,
+                "min_encoder_length": self.min_encoder_length,
+                "min_prediction_length": self.min_prediction_length,
+            }
+        )
+
+        return metadata
+
+    @property
+    def metadata(self):
+        """Compute metadata for model initialization.
+
+        This property returns a dictionary containing the shapes and key information
+        related to the time series model. The metadata includes:
+
+        * ``encoder_cat``: Number of categorical variables in the encoder.
+        * ``encoder_cont``: Number of continuous variables in the encoder.
+        * ``decoder_cat``: Number of categorical variables in the decoder that are
+          known in advance.
+        * ``decoder_cont``: Number of continuous variables in the decoder that are
+          known in advance.
+ * ``target``: Number of target variables. + + If static features are present, the following keys are added: + + * ``static_categorical_features``: Number of static categorical features + * ``static_continuous_features``: Number of static continuous features + + It also contains the following information: + + * ``max_encoder_length``: maximum encoder length + * ``max_prediction_length``: maximum prediction length + * ``min_encoder_length``: minimum encoder length + * ``min_prediction_length``: minimum prediction length + """ + if self._metadata is None: + self._metadata = self._prepare_metadata() + return self._metadata + + def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + + Preprocessing steps + -------------------- + + * Converts target (`y`) and features (`x`) to `torch.float32`. + * Masks time points that are at or before the cutoff time. + * Splits features into categorical and continuous subsets based on + predefined indices. + + + TODO: add scalers, target normalizers etc. + """ + processed_data = [] + + for idx in indices: + sample = self.time_series_dataset[idx.item()] + + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] + + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + # TODO: add scalers, target normalizers etc. + + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) + + processed_data.append( + { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } + ) + + return processed_data + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed encoder-decoder time series data. + + Parameters + ---------- + processed_data : List[Dict[str, Any]] + List of preprocessed time series samples. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. + """ + + def __init__( + self, + processed_data: List[Dict[str, Any]], + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.processed_data = processed_data + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + """Retrieve a processed time series window for dataloader input. + + x : dict + Dictionary containing model inputs: + + * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features) + Categorical features for the encoder. + * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features) + Continuous features for the encoder. 
+ * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features) + Categorical features for the decoder. + * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features) + Continuous features for the decoder. + * ``encoder_lengths`` : tensor of shape (1,) + Length of the encoder sequence. + * ``decoder_lengths`` : tensor of shape (1,) + Length of the decoder sequence. + * ``decoder_target_lengths`` : tensor of shape (1,) + Length of the decoder target sequence. + * ``groups`` : tensor of shape (1,) + Group identifier for the time series instance. + * ``encoder_time_idx`` : tensor of shape (enc_length,) + Time indices for the encoder sequence. + * ``decoder_time_idx`` : tensor of shape (pred_length,) + Time indices for the decoder sequence. + * ``target_scale`` : tensor of shape (1,) + Scaling factor for the target values. + * ``encoder_mask`` : tensor of shape (enc_length,) + Boolean mask indicating valid encoder time points. + * ``decoder_mask`` : tensor of shape (pred_length,) + Boolean mask indicating valid decoder time points. + + If static features are present, the following keys are added: + + * ``static_categorical_features`` : tensor of shape + (1, n_static_cat_features), optional + Static categorical features, if available. + * ``static_continuous_features`` : tensor of shape (1, 0), optional + Placeholder for static continuous features (currently empty). + + y : tensor of shape ``(pred_length, n_targets)`` + Target values for the decoder sequence. + """ + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.processed_data[series_idx] + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices] + target_scale = target_scale[~torch.isnan(target_scale)].abs().mean() + if torch.isnan(target_scale) or target_scale == 0: + target_scale = torch.tensor(1.0) + + encoder_mask = ( + data["time_mask"][encoder_indices] + if "time_mask" in data + else torch.ones(enc_length, dtype=torch.bool) + ) + decoder_mask = ( + data["time_mask"][decoder_indices] + if "time_mask" in data + else torch.zeros(pred_length, dtype=torch.bool) + ) + + x = { + "encoder_cat": data["features"]["categorical"][encoder_indices], + "encoder_cont": data["features"]["continuous"][encoder_indices], + "decoder_cat": data["features"]["categorical"][decoder_indices], + "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_lengths": torch.tensor(enc_length), + "decoder_lengths": torch.tensor(pred_length), + "decoder_target_lengths": torch.tensor(pred_length), + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + "target_scale": target_scale, + "encoder_mask": encoder_mask, + "decoder_mask": decoder_mask, + } + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows( + self, processed_data: List[Dict[str, Any]] + ) -> List[Tuple[int, int, int, int]]: + """Generate sliding windows for training, validation, and testing. + + Returns + ------- + List[Tuple[int, int, int, int]] + A list of tuples, where each tuple consists of: + - ``series_idx`` : int + Index of the time series in `processed_data`. 
+ - ``start_idx`` : int + Start index of the encoder window. + - ``enc_length`` : int + Length of the encoder input sequence. + - ``pred_length`` : int + Length of the decoder output sequence. + """ + windows = [] + + for idx, data in enumerate(processed_data): + sequence_length = data["length"] + + if sequence_length < self.max_encoder_length + self.max_prediction_length: + continue + + effective_min_prediction_idx = ( + self.min_prediction_idx + if self.min_prediction_idx is not None + else self.max_encoder_length + ) + + max_prediction_idx = sequence_length - self.max_prediction_length + 1 + + if max_prediction_idx <= effective_min_prediction_idx: + continue + + for start_idx in range( + 0, max_prediction_idx - effective_min_prediction_idx + ): + if ( + start_idx + self.max_encoder_length + self.max_prediction_length + <= sequence_length + ): + windows.append( + ( + idx, + start_idx, + self.max_encoder_length, + self.max_prediction_length, + ) + ) + + return windows + + def setup(self, stage: Optional[str] = None): + """Prepare the datasets for training, validation, testing, or prediction. + + Parameters + ---------- + stage : Optional[str], default=None + Specifies the stage of setup. Can be one of: + - ``"fit"`` : Prepares training and validation datasets. + - ``"test"`` : Prepares the test dataset. + - ``"predict"`` : Prepares the dataset for inference. + - ``None`` : Prepares all datasets. + """ + total_series = len(self.time_series_dataset) + self._split_indices = torch.randperm(total_series) + + self._train_size = int(self.train_val_test_split[0] * total_series) + self._val_size = int(self.train_val_test_split[1] * total_series) + + self._train_indices = self._split_indices[: self._train_size] + self._val_indices = self._split_indices[ + self._train_size : self._train_size + self._val_size + ] + self._test_indices = self._split_indices[self._train_size + self._val_size :] + + if stage is None or stage == "fit": + if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self.train_processed = self._preprocess_data(self._train_indices) + self.val_processed = self._preprocess_data(self._val_indices) + + self.train_windows = self._create_windows(self.train_processed) + self.val_windows = self._create_windows(self.val_processed) + + self.train_dataset = self._ProcessedEncoderDecoderDataset( + self.train_processed, self.train_windows, self.add_relative_time_idx + ) + self.val_dataset = self._ProcessedEncoderDecoderDataset( + self.val_processed, self.val_windows, self.add_relative_time_idx + ) + # print(self.val_dataset[0]) + + elif stage is None or stage == "test": + if not hasattr(self, "test_dataset"): + self.test_processed = self._preprocess_data(self._test_indices) + self.test_windows = self._create_windows(self.test_processed) + + self.test_dataset = self._ProcessedEncoderDecoderDataset( + self.test_processed, self.test_windows, self.add_relative_time_idx + ) + elif stage == "predict": + predict_indices = torch.arange(len(self.time_series_dataset)) + self.predict_processed = self._preprocess_data(predict_indices) + self.predict_windows = self._create_windows(self.predict_processed) + self.predict_dataset = self._ProcessedEncoderDecoderDataset( + self.predict_processed, self.predict_windows, self.add_relative_time_idx + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + 
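+            # unlike the training loader, evaluation batches are not shuffled,
+            # so validation windows keep their original order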
self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + @staticmethod + def collate_fn(batch): + x_batch = { + "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]), + "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]), + "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]), + "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]), + "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]), + "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]), + "decoder_target_lengths": torch.stack( + [x["decoder_target_lengths"] for x, _ in batch] + ), + "groups": torch.stack([x["groups"] for x, _ in batch]), + "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), + "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), + "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]), + "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in batch]), + } + + if "static_categorical_features" in batch[0][0]: + x_batch["static_categorical_features"] = torch.stack( + [x["static_categorical_features"] for x, _ in batch] + ) + x_batch["static_continuous_features"] = torch.stack( + [x["static_continuous_features"] for x, _ in batch] + ) + + y_batch = torch.stack([y for _, y in batch]) + return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 336eecd5f..bc8300300 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2657,6 +2657,8 @@ def _coerce_to_list(obj): """ if obj is None: return [] + if isinstance(obj, str): + return [obj] return list(obj) @@ -2668,3 +2670,258 @@ def _coerce_to_dict(obj): if obj is None: return {} return deepcopy(obj) + + +####################################################################################### +# Disclaimer: This dataset class is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the data-handling pipeline may +# look like in the future. +# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion +# and turning the data to tensors. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. 
This column is used to determine the sequence of samples.
+        If there are no missing observations,
+        the time index should increase by ``+1`` for each subsequent sample.
+        The first time_idx for each series does not necessarily
+        have to be ``0``; any starting value is allowed.
+    target : str or List[str], optional, default = last column (at iloc -1)
+        column(s) in ``data`` denoting the forecasting target.
+        Can be categorical or numerical dtype.
+    group : List[str], optional, default = None
+        list of column names identifying a time series instance within ``data``.
+        This means that the ``group`` columns together uniquely identify an instance,
+        and ``group`` together with ``time`` uniquely identify a single observation
+        within a time series instance.
+        If ``None``, the dataset is assumed to be a single time series.
+    weight : str, optional, default=None
+        column name for weights.
+        If ``None``, it is assumed that there is no weight column.
+    num : list of str, optional, default = all columns with dtype in "fi"
+        list of numerical variables in ``data``,
+        list may also contain list of str, which are then grouped together.
+    cat : list of str, optional, default = all columns with dtype in "Obc"
+        list of categorical variables in ``data``,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for product categories).
+    known : list of str, optional, default = all variables
+        list of variables that change over time and are known in the future,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for special days or promotion categories).
+    unknown : list of str, optional, default = no variables
+        list of variables that are not known in the future,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for weather categories).
+    static : list of str, optional, default = all variables not in known, unknown
+        list of variables that do not change over time,
+        list may also contain list of str, which are then grouped together.
+    """
+
+    def __init__(
+        self,
+        data: pd.DataFrame,
+        data_future: Optional[pd.DataFrame] = None,
+        time: Optional[str] = None,
+        target: Optional[Union[str, List[str]]] = None,
+        group: Optional[List[str]] = None,
+        weight: Optional[str] = None,
+        num: Optional[List[Union[str, List[str]]]] = None,
+        cat: Optional[List[Union[str, List[str]]]] = None,
+        known: Optional[List[Union[str, List[str]]]] = None,
+        unknown: Optional[List[Union[str, List[str]]]] = None,
+        static: Optional[List[Union[str, List[str]]]] = None,
+    ):
+
+        self.data = data
+        self.data_future = data_future
+        self.time = time
+        self.target = _coerce_to_list(target)
+        self.group = _coerce_to_list(group)
+        self.weight = weight
+        self.num = _coerce_to_list(num)
+        self.cat = _coerce_to_list(cat)
+        self.known = _coerce_to_list(known)
+        self.unknown = _coerce_to_list(unknown)
+        self.static = _coerce_to_list(static)
+
+        self.feature_cols = [
+            col
+            for col in data.columns
+            if col not in [self.time] + self.group + [self.weight] + self.target
+        ]
+        if self.group:
+            self._groups = self.data.groupby(self.group).groups
+            self._group_ids = list(self._groups.keys())
+        else:
+            self._groups = {"_single_group": self.data.index}
+            self._group_ids = ["_single_group"]
+
+        self._prepare_metadata()
+
+    def _prepare_metadata(self):
+        """Prepare metadata for the dataset.
+
+        The function returns metadata that contains:
+
+        * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] }
+            Names of columns for y, x, and static features.
+ List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + It returns: + + * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with ``y``, + and ``x`` not ending in ``f``. + * ``y``: tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with ``t``. + * ``x``: tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with ``t``. + * ``group``: tensor of shape (n_groups) + Group identifiers for time series instances. + * ``st``: tensor of shape (n_static_features) + Static features. + * ``cutoff_time``: float or ``numpy.float64`` + Cutoff time for the time series instance. + + Optionally, the following str-keyed entry can be included: + + * ``weights``: tensor of shape (n_timepoints), only if weight is not None + """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": 
torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if self.weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[self.weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata From d0d1c3ec7fb3bdee8e80d9ff83cd43e8990a5319 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 18:47:46 +0530 Subject: [PATCH 02/33] remove one comment --- pytorch_forecasting/data/data_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 56917696d..2958f1705 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -550,7 +550,6 @@ def setup(self, stage: Optional[str] = None): self.val_dataset = self._ProcessedEncoderDecoderDataset( self.val_processed, self.val_windows, self.add_relative_time_idx ) - # print(self.val_dataset[0]) elif stage is None or stage == "test": if not hasattr(self, "test_dataset"): From 80e64d218a744557bd493ea07547f0f42b029573 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:07:01 +0530 Subject: [PATCH 03/33] model layer commit --- .../models/base/base_model_refactor.py | 283 ++++++++++++++++++ .../tft_version_two.py | 218 ++++++++++++++ 2 files changed, 501 insertions(+) create mode 100644 pytorch_forecasting/models/base/base_model_refactor.py create mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py diff --git a/pytorch_forecasting/models/base/base_model_refactor.py b/pytorch_forecasting/models/base/base_model_refactor.py new file mode 100644 index 000000000..ccd2c2600 --- /dev/null +++ b/pytorch_forecasting/models/base/base_model_refactor.py @@ -0,0 +1,283 @@ +######################################################################################## +# Disclaimer: This baseclass is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the base classes may look like +# in the version-2. 
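+# For now, this skeleton only wires up the Lightning train/val/test/predict
+# steps together with optimizer and learning-rate-scheduler configuration;
+# subclasses are expected to implement ``forward``.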
+######################################################################################## + + +from typing import Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +import torch +import torch.nn as nn +from torch.optim import Optimizer + + +class BaseModel(LightningModule): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + ): + """ + Base model for time series forecasting. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + logging_metrics : Optional[List[nn.Module]], optional + List of metrics to log during training, validation, and testing. + optimizer : Optional[Union[Optimizer, str]], optional + Optimizer to use for training. + Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer. + lr_scheduler : Optional[str], optional + Learning rate scheduler to use. + Supported values: "reduce_lr_on_plateau", "step_lr". + lr_scheduler_params : Optional[Dict], optional + Parameters for the learning rate scheduler. + """ + super().__init__() + self.loss = loss + self.logging_metrics = logging_metrics if logging_metrics is not None else [] + self.optimizer = optimizer + self.optimizer_params = optimizer_params if optimizer_params is not None else {} + self.lr_scheduler = lr_scheduler + self.lr_scheduler_params = ( + lr_scheduler_params if lr_scheduler_params is not None else {} + ) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors + """ + raise NotImplementedError("Forward method must be implemented by subclass.") + + def training_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Training step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="train") + return {"loss": loss} + + def validation_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Validation step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. 
+ """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="val") + return {"val_loss": loss} + + def test_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Test step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="test") + return {"test_loss": loss} + + def predict_step( + self, + batch: Tuple[Dict[str, torch.Tensor]], + batch_idx: int, + dataloader_idx: int = 0, + ) -> torch.Tensor: + """ + Prediction step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input tensors. + batch_idx : int + Index of the batch. + dataloader_idx : int + Index of the dataloader. + + Returns + ------- + torch.Tensor + Predicted output tensor. + """ + x, _ = batch + y_hat = self(x) + return y_hat + + def configure_optimizers(self) -> Dict: + """ + Configure the optimizer and learning rate scheduler. + + Returns + ------- + Dict + Dictionary containing the optimizer and scheduler configuration. + """ + optimizer = self._get_optimizer() + if self.lr_scheduler is not None: + scheduler = self._get_scheduler(optimizer) + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + }, + } + else: + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} + + def _get_optimizer(self) -> Optimizer: + """ + Get the optimizer based on the specified optimizer name and parameters. + + Returns + ------- + Optimizer + The optimizer instance. + """ + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adam": + return torch.optim.Adam(self.parameters(), **self.optimizer_params) + elif self.optimizer.lower() == "sgd": + return torch.optim.SGD(self.parameters(), **self.optimizer_params) + else: + raise ValueError(f"Optimizer {self.optimizer} not supported.") + elif isinstance(self.optimizer, Optimizer): + return self.optimizer + else: + raise ValueError( + "Optimizer must be either a string or " + "an instance of torch.optim.Optimizer." + ) + + def _get_scheduler( + self, optimizer: Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: + """ + Get the lr scheduler based on the specified scheduler name and params. + + Parameters + ---------- + optimizer : Optimizer + The optimizer instance. + + Returns + ------- + torch.optim.lr_scheduler._LRScheduler + The learning rate scheduler instance. 
+ """ + if self.lr_scheduler.lower() == "reduce_lr_on_plateau": + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, **self.lr_scheduler_params + ) + elif self.lr_scheduler.lower() == "step_lr": + return torch.optim.lr_scheduler.StepLR( + optimizer, **self.lr_scheduler_params + ) + else: + raise ValueError(f"Scheduler {self.lr_scheduler} not supported.") + + def log_metrics( + self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val" + ) -> None: + """ + Log additional metrics during training, validation, or testing. + + Parameters + ---------- + y_hat : torch.Tensor + Predicted output tensor. + y : torch.Tensor + Target output tensor. + prefix : str + Prefix for the logged metrics (e.g., "train", "val", "test"). + """ + for metric in self.logging_metrics: + metric_value = metric(y_hat, y) + self.log( + f"{prefix}_{metric.__class__.__name__}", + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py new file mode 100644 index 000000000..30f70f98e --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -0,0 +1,218 @@ +######################################################################################## +# Disclaimer: This implementation is based on the new version of data pipeline and is +# experimental, please use with care. +######################################################################################## + +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_forecasting.models.base.base_model_refactor import BaseModel + + +class TFT(BaseModel): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + hidden_size: int = 64, + num_layers: int = 2, + attention_head_size: int = 4, + dropout: float = 0.1, + metadata: Optional[Dict] = None, + output_size: int = 1, + ): + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params, + ) + self.hidden_size = hidden_size + self.num_layers = num_layers + self.attention_head_size = attention_head_size + self.dropout = dropout + self.metadata = metadata + self.output_size = output_size + + self.max_encoder_length = self.metadata["max_encoder_length"] + self.max_prediction_length = self.metadata["max_prediction_length"] + self.encoder_cont = self.metadata["encoder_cont"] + self.encoder_cat = self.metadata["encoder_cat"] + self.static_categorical_features = self.metadata["static_categorical_features"] + self.static_continuous_features = self.metadata["static_continuous_features"] + + total_feature_size = self.encoder_cont + self.encoder_cat + total_static_size = ( + self.static_categorical_features + self.static_continuous_features + ) + + self.encoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + nn.Sigmoid(), + ) + + self.decoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + 
nn.Sigmoid(), + ) + + self.static_context_linear = ( + nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None + ) + + self.lstm_encoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.lstm_decoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.self_attention = nn.MultiheadAttention( + embed_dim=hidden_size, + num_heads=attention_head_size, + dropout=dropout, + batch_first=True, + ) + + self.pre_output = nn.Linear(hidden_size, hidden_size) + self.output_layer = nn.Linear(hidden_size, output_size) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the TFT model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors: + - encoder_cat: Categorical encoder features + - encoder_cont: Continuous encoder features + - decoder_cat: Categorical decoder features + - decoder_cont: Continuous decoder features + - static_categorical_features: Static categorical features + - static_continuous_features: Static continuous features + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors: + - prediction: Prediction output (batch_size, prediction_length, output_size) + """ + batch_size = x["encoder_cont"].shape[0] + + encoder_cat = x.get( + "encoder_cat", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + encoder_cont = x.get( + "encoder_cont", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + decoder_cat = x.get( + "decoder_cat", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + decoder_cont = x.get( + "decoder_cont", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + + encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2) + decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2) + + static_context = None + if self.static_context_linear is not None: + static_cat = x.get( + "static_categorical_features", + torch.zeros(batch_size, 0, device=self.device), + ) + static_cont = x.get( + "static_continuous_features", + torch.zeros(batch_size, 0, device=self.device), + ) + + if static_cat.size(2) == 0 and static_cont.size(2) == 0: + static_context = None + elif static_cat.size(2) == 0: + static_input = static_cont.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + elif static_cont.size(2) == 0: + static_input = static_cat.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + else: + + static_input = torch.cat([static_cont, static_cat], dim=1).to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + + encoder_weights = self.encoder_var_selection(encoder_input) + encoder_input = encoder_input * encoder_weights + + decoder_weights = self.decoder_var_selection(decoder_input) + decoder_input = decoder_input * decoder_weights + + if static_context is not None: + encoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_encoder_length, 
-1 + ) + decoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_prediction_length, -1 + ) + + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + encoder_output = encoder_output + encoder_static_context + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + decoder_output = decoder_output + decoder_static_context + else: + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + + sequence = torch.cat([encoder_output, decoder_output], dim=1) + + if static_context is not None: + expanded_static_context = static_context.unsqueeze(1).expand( + -1, sequence.size(1), -1 + ) + + attended_output, _ = self.self_attention( + sequence + expanded_static_context, sequence, sequence + ) + else: + attended_output, _ = self.self_attention(sequence, sequence, sequence) + + decoder_attended = attended_output[:, -self.max_prediction_length :, :] + + output = nn.functional.relu(self.pre_output(decoder_attended)) + prediction = self.output_layer(output) + + return {"prediction": prediction} From 6364780ae121298e3d98a2c14c6f6747bf62a7b4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:34:57 +0530 Subject: [PATCH 04/33] update docstring --- pytorch_forecasting/data/timeseries.py | 44 +++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index bc8300300..9da02d3a0 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2815,25 +2815,31 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: """Get time series data for given index. - It returns: - - * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) - Time index for each time point in the past or present. Aligned with ``y``, - and ``x`` not ending in ``f``. - * ``y``: tensor of shape (n_timepoints, n_targets) - Target values for each time point. Rows are time points, aligned with ``t``. - * ``x``: tensor of shape (n_timepoints, n_features) - Features for each time point. Rows are time points, aligned with ``t``. - * ``group``: tensor of shape (n_groups) - Group identifiers for time series instances. - * ``st``: tensor of shape (n_static_features) - Static features. - * ``cutoff_time``: float or ``numpy.float64`` - Cutoff time for the time series instance. - - Optionally, the following str-keyed entry can be included: - - * ``weights``: tensor of shape (n_timepoints), only if weight is not None + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. 
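+
+        Examples
+        --------
+        A minimal usage sketch, assuming ``df`` is a pandas DataFrame with a
+        time column ``"t"``, a group column ``"g"``, and a target column
+        ``"y"`` (hypothetical names):
+
+        >>> ds = TimeSeries(df, time="t", target="y", group=["g"])
+        >>> sample = ds[0]
+        >>> sorted(sample)
+        ['cutoff_time', 'group', 'st', 't', 'x', 'y']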
""" group_id = self._group_ids[index] From 257183ce4d2b1f7fd40c95ecd7dc38c8004a017b Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 01:54:50 +0530 Subject: [PATCH 05/33] update data_module.py --- pytorch_forecasting/data/data_module.py | 160 ++++++++++++------------ 1 file changed, 80 insertions(+), 80 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 2958f1705..c796b85fa 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -1,15 +1,6 @@ -####################################################################################### -# Disclaimer: This data-module is still work in progress and experimental, please -# use with care. This data-module is a basic skeleton of how the data-handling pipeline -# may look like in the future. -# This is D2 layer that will handle the preprocessing and data loaders. -# For now, this pipeline handles the simplest situation: The whole data can be loaded -# into the memory. -####################################################################################### - from typing import Any, Dict, List, Optional, Tuple, Union -from lightning.pytorch import LightningDataModule +from lightning.pytorch import LightningDataModule, LightningModule from sklearn.preprocessing import RobustScaler, StandardScaler import torch from torch.utils.data import DataLoader, Dataset @@ -19,7 +10,11 @@ NaNLabelEncoder, TorchNormalizer, ) -from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict +from pytorch_forecasting.data.timeseries import ( + TimeSeries, + _coerce_to_dict, + _coerce_to_list, +) NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] @@ -274,7 +269,7 @@ def metadata(self): self._metadata = self._prepare_metadata() return self._metadata - def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + def _preprocess_data(self, series_idx: torch.Tensor) -> List[Dict[str, Any]]: """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. Preprocessing steps @@ -288,63 +283,58 @@ def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: TODO: add scalers, target normalizers etc. """ - processed_data = [] + sample = self.time_series_dataset[series_idx] - for idx in indices: - sample = self.time_series_dataset[idx.item()] + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] - target = sample["y"] - features = sample["x"] - times = sample["t"] - cutoff_time = sample["cutoff_time"] + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) - time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) - - if isinstance(target, torch.Tensor): - target = target.float() - else: - target = torch.tensor(target, dtype=torch.float32) - - if isinstance(features, torch.Tensor): - features = features.float() - else: - features = torch.tensor(features, dtype=torch.float32) + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) - # TODO: add scalers, target normalizers etc. 
+
+        categorical = (
+            features[:, self.categorical_indices]
+            if self.categorical_indices
+            else torch.zeros((features.shape[0], 0))
+        )
+        continuous = (
+            features[:, self.continuous_indices]
+            if self.continuous_indices
+            else torch.zeros((features.shape[0], 0))
+        )
 
-            processed_data.append(
-                {
-                    "features": {"categorical": categorical, "continuous": continuous},
-                    "target": target,
-                    "static": sample.get("st", None),
-                    "group": sample.get("group", torch.tensor([0])),
-                    "length": len(target),
-                    "time_mask": time_mask,
-                    "times": times,
-                    "cutoff_time": cutoff_time,
-                }
-            )
-
-        return processed_data
+        return {
+            "features": {"categorical": categorical, "continuous": continuous},
+            "target": target,
+            "static": sample.get("st", None),
+            "group": sample.get("group", torch.tensor([0])),
+            "length": len(target),
+            "time_mask": time_mask,
+            "times": times,
+            "cutoff_time": cutoff_time,
+        }
 
     class _ProcessedEncoderDecoderDataset(Dataset):
         """PyTorch Dataset for processed encoder-decoder time series data.
 
         Parameters
         ----------
-        processed_data : List[Dict[str, Any]]
-            List of preprocessed time series samples.
+        dataset : TimeSeries
+            The base time series dataset that provides access to raw data and metadata.
+        data_module : EncoderDecoderTimeSeriesDataModule
+            The data module handling preprocessing and metadata configuration.
         windows : List[Tuple[int, int, int, int]]
             List of window tuples containing
             (series_idx, start_idx, enc_length, pred_length).
@@ -354,11 +344,13 @@ class _ProcessedEncoderDecoderDataset(Dataset):
 
         def __init__(
             self,
-            processed_data: List[Dict[str, Any]],
+            dataset: TimeSeries,
+            data_module: "EncoderDecoderTimeSeriesDataModule",
             windows: List[Tuple[int, int, int, int]],
             add_relative_time_idx: bool = False,
         ):
-            self.processed_data = processed_data
+            self.dataset = dataset
+            self.data_module = data_module
             self.windows = windows
             self.add_relative_time_idx = add_relative_time_idx
 
@@ -410,7 +402,7 @@ def __getitem__(self, idx):
             Target values for the decoder sequence.
         """
         series_idx, start_idx, enc_length, pred_length = self.windows[idx]
-        data = self.processed_data[series_idx]
+        data = self.data_module._preprocess_data(series_idx)
 
         end_idx = start_idx + enc_length + pred_length
         encoder_indices = slice(start_idx, start_idx + enc_length)
@@ -457,9 +449,7 @@ def __getitem__(self, idx):
 
         return x, y
 
-    def _create_windows(
-        self, processed_data: List[Dict[str, Any]]
-    ) -> List[Tuple[int, int, int, int]]:
+    def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]:
         """Generate sliding windows for training, validation, and testing.
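+
+        Windows are generated with a stride of 1, keeping every start position
+        at which a full encoder-plus-decoder span fits into the series (subject
+        to ``min_prediction_idx``).
+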
Returns @@ -477,8 +467,10 @@ def _create_windows( """ windows = [] - for idx, data in enumerate(processed_data): - sequence_length = data["length"] + for idx in indices: + series_idx = idx.item() + sample = self.time_series_dataset[series_idx] + sequence_length = len(sample["y"]) if sequence_length < self.max_encoder_length + self.max_prediction_length: continue @@ -503,7 +495,7 @@ def _create_windows( ): windows.append( ( - idx, + series_idx, start_idx, self.max_encoder_length, self.max_prediction_length, @@ -538,33 +530,41 @@ def setup(self, stage: Optional[str] = None): if stage is None or stage == "fit": if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): - self.train_processed = self._preprocess_data(self._train_indices) - self.val_processed = self._preprocess_data(self._val_indices) - - self.train_windows = self._create_windows(self.train_processed) - self.val_windows = self._create_windows(self.val_processed) + self.train_windows = self._create_windows(self._train_indices) + self.val_windows = self._create_windows(self._val_indices) self.train_dataset = self._ProcessedEncoderDecoderDataset( - self.train_processed, self.train_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.train_windows, + self.add_relative_time_idx, ) self.val_dataset = self._ProcessedEncoderDecoderDataset( - self.val_processed, self.val_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.val_windows, + self.add_relative_time_idx, ) - elif stage is None or stage == "test": + elif stage == "test": if not hasattr(self, "test_dataset"): - self.test_processed = self._preprocess_data(self._test_indices) - self.test_windows = self._create_windows(self.test_processed) - + self.test_windows = self._create_windows(self._test_indices) self.test_dataset = self._ProcessedEncoderDecoderDataset( - self.test_processed, self.test_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.test_windows, + self, + self.add_relative_time_idx, ) elif stage == "predict": predict_indices = torch.arange(len(self.time_series_dataset)) - self.predict_processed = self._preprocess_data(predict_indices) - self.predict_windows = self._create_windows(self.predict_processed) + self.predict_windows = self._create_windows(predict_indices) self.predict_dataset = self._ProcessedEncoderDecoderDataset( - self.predict_processed, self.predict_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.predict_windows, + self, + self.add_relative_time_idx, ) def train_dataloader(self): From 9cdcb195c4c9e3f9b6d0e76ef3b6ed889bc14998 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 01:56:55 +0530 Subject: [PATCH 06/33] update data_module.py --- pytorch_forecasting/data/data_module.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index c796b85fa..9a4a5bf5e 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -553,7 +553,6 @@ def setup(self, stage: Optional[str] = None): self.time_series_dataset, self, self.test_windows, - self, self.add_relative_time_idx, ) elif stage == "predict": @@ -563,7 +562,6 @@ def setup(self, stage: Optional[str] = None): self.time_series_dataset, self, self.predict_windows, - self, self.add_relative_time_idx, ) From ac56d4fd56aeeb1287f162559c67e785de4446f4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 02:05:58 +0530 Subject: [PATCH 07/33] Add 
disclaimer --- pytorch_forecasting/data/data_module.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 9a4a5bf5e..b33a11d47 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -1,6 +1,15 @@ +####################################################################################### +# Disclaimer: This data-module is still work in progress and experimental, please +# use with care. This data-module is a basic skeleton of how the data-handling pipeline +# may look like in the future. +# This is D2 layer that will handle the preprocessing and data loaders. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + from typing import Any, Dict, List, Optional, Tuple, Union -from lightning.pytorch import LightningDataModule, LightningModule +from lightning.pytorch import LightningDataModule from sklearn.preprocessing import RobustScaler, StandardScaler import torch from torch.utils.data import DataLoader, Dataset @@ -13,7 +22,6 @@ from pytorch_forecasting.data.timeseries import ( TimeSeries, _coerce_to_dict, - _coerce_to_list, ) NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] From 4bfff21de1a75be0c93dcb713cb91defe6bc2fad Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 12:44:44 +0530 Subject: [PATCH 08/33] update docstring --- pytorch_forecasting/data/data_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index b33a11d47..9d4e0b02f 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -465,7 +465,7 @@ def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, in List[Tuple[int, int, int, int]] A list of tuples, where each tuple consists of: - ``series_idx`` : int - Index of the time series in `processed_data`. + Index of the time series in `time_series_dataset`. - ``start_idx`` : int Start index of the encoder window. - ``enc_length`` : int @@ -522,7 +522,7 @@ def setup(self, stage: Optional[str] = None): - ``"fit"`` : Prepares training and validation datasets. - ``"test"`` : Prepares the test dataset. - ``"predict"`` : Prepares the dataset for inference. - - ``None`` : Prepares all datasets. + - ``None`` : Prepares ``fit`` datasets. 
""" total_series = len(self.time_series_dataset) self._split_indices = torch.randperm(total_series) From 8a53ed63933b0b92d752eaa707eadc7c45d35566 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 19 Apr 2025 19:37:40 +0530 Subject: [PATCH 09/33] Add tests for D1,D2 layer --- pytorch_forecasting/data/data_module.py | 56 ++- tests/test_data/test_d1.py | 379 +++++++++++++++++++ tests/test_data/test_data_module.py | 464 ++++++++++++++++++++++++ 3 files changed, 895 insertions(+), 4 deletions(-) create mode 100644 tests/test_data/test_d1.py create mode 100644 tests/test_data/test_data_module.py diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 9d4e0b02f..1203e83ac 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -432,11 +432,59 @@ def __getitem__(self, idx): else torch.zeros(pred_length, dtype=torch.bool) ) + encoder_cat = data["features"]["categorical"][encoder_indices] + encoder_cont = data["features"]["continuous"][encoder_indices] + + features = data["features"] + metadata = self.data_module.time_series_metadata + + known_cat_indices = [ + i + for i, col in enumerate(metadata["cols"]["x"]) + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + + known_cont_indices = [ + i + for i, col in enumerate(metadata["cols"]["x"]) + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + + cat_map = { + orig_idx: i + for i, orig_idx in enumerate(self.data_module.categorical_indices) + } + cont_map = { + orig_idx: i + for i, orig_idx in enumerate(self.data_module.continuous_indices) + } + + mapped_known_cat_indices = [ + cat_map[idx] for idx in known_cat_indices if idx in cat_map + ] + mapped_known_cont_indices = [ + cont_map[idx] for idx in known_cont_indices if idx in cont_map + ] + + decoder_cat = ( + features["categorical"][decoder_indices][:, mapped_known_cat_indices] + if mapped_known_cat_indices + else torch.zeros((pred_length, 0)) + ) + + decoder_cont = ( + features["continuous"][decoder_indices][:, mapped_known_cont_indices] + if mapped_known_cont_indices + else torch.zeros((pred_length, 0)) + ) + x = { - "encoder_cat": data["features"]["categorical"][encoder_indices], - "encoder_cont": data["features"]["continuous"][encoder_indices], - "decoder_cat": data["features"]["categorical"][decoder_indices], - "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_cat": encoder_cat, + "encoder_cont": encoder_cont, + "decoder_cat": decoder_cat, + "decoder_cont": decoder_cont, "encoder_lengths": torch.tensor(enc_length), "decoder_lengths": torch.tensor(pred_length), "decoder_target_lengths": torch.tensor(pred_length), diff --git a/tests/test_data/test_d1.py b/tests/test_data/test_d1.py new file mode 100644 index 000000000..b32c13213 --- /dev/null +++ b/tests/test_data/test_d1.py @@ -0,0 +1,379 @@ +import numpy as np +import pandas as pd +import pytest +import torch + +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_data(): + """Create time series data for testing.""" + dates = pd.date_range(start="2023-01-01", periods=10, freq="D") + data = pd.DataFrame( + { + "timestamp": dates, + "target_value": np.sin(np.arange(10)) + 10, + "feature1": np.random.randn(10), + "feature2": np.random.randn(10), + "feature3": np.random.randn(10), + "group_id": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], + "weight": np.abs(np.random.randn(10)) + 0.1, + "static_feat": [10, 10, 10, 10, 10, 20, 20, 
20, 20, 20], + } + ) + return data + + +@pytest.fixture +def future_data(): + """Create future time series data.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + data = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + "feature2": np.random.randn(5), + "feature3": np.random.randn(5), + "group_id": [1, 1, 1, 2, 2], + "weight": np.abs(np.random.randn(5)) + 0.1, + "static_feat": [10, 10, 10, 20, 20], + } + ) + return data + + +def test_init_basic(sample_data): + """Test basic initialization of TimeSeries class. + + Ensures that the class stores time, target, and correctly detects feature columns + when no group, known/unknown features, or static/weight features are specified.""" + ts = TimeSeries(data=sample_data, time="timestamp", target="target_value") + + assert ts.time == "timestamp" + assert ts.target == ["target_value"] + assert len(ts.feature_cols) == 6 # All columns except timestamp, target_value + assert len(ts) == 1 # Single group by default + + +def test_init_with_groups(sample_data): + """Test initialization with group parameter. + + Verifies that data is grouped correctly and each group is handled as a + separate time series. + """ + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", group=["group_id"] + ) + + assert ts.group == ["group_id"] + assert len(ts) == 2 # Two groups (1 and 2) + assert set(ts._group_ids) == {1, 2} + + +def test_init_with_features_categorization(sample_data): + """Test feature categorization. + + Ensures that numeric, categorical, and static features are categorized and + stored correctly in metadata.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + num=["feature1", "feature2", "feature3"], + cat=[], + static=["static_feat"], + ) + + assert ts.num == ["feature1", "feature2", "feature3"] + assert ts.cat == [] + assert ts.static == ["static_feat"] + assert ts.metadata["col_type"]["feature1"] == "F" + assert ts.metadata["col_type"]["feature2"] == "F" + + +def test_init_with_known_unknown(sample_data): + """Test known and unknown features classification. + + Checks if the known and unknown feature categorization is correctly set + and stored in metadata.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + known=["feature1"], + unknown=["feature2", "feature3"], + ) + + assert ts.known == ["feature1"] + assert ts.unknown == ["feature2", "feature3"] + assert ts.metadata["col_known"]["feature1"] == "K" + assert ts.metadata["col_known"]["feature2"] == "U" + + +def test_init_with_weight(sample_data): + """Test initialization with weight parameter. + + Verifies that the weight column is stored correctly and excluded + from the feature columns.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", weight="weight" + ) + + assert ts.weight == "weight" + assert "weight" not in ts.feature_cols + + +def test_getitem_basic(sample_data): + """Test __getitem__ with basic configuration. 
+ + Checks the output structure of a single time series without grouping, + ensuring x, y are tensors of correct shapes.""" + ts = TimeSeries(data=sample_data, time="timestamp", target="target_value") + + result = ts[0] + assert torch.is_tensor(result["y"]) + assert torch.is_tensor(result["x"]) + assert "t" in result + assert "cutoff_time" in result + assert len(result["y"]) == 10 # 10 data points + assert result["y"].shape == (10, 1) # One target variable + assert result["x"].shape[1] == 6 # Six feature columns + + +def test_getitem_with_groups(sample_data): + """Test __getitem__ with groups parameter. + + Verifies the per-group access using index and checks that each group + has the correct number of time steps.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", group=["group_id"] + ) + + # group (1) + result_g1 = ts[0] + assert len(result_g1["t"]) == 5 # 5 data points in group 1 + + # group (2) + result_g2 = ts[1] + assert len(result_g2["t"]) == 5 # 5 data points in group 2 + + +def test_getitem_with_static(sample_data): + """Test __getitem__ with static features. + + Ensures static features are included in the output and correctly + mapped per group.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + group=["group_id"], + static=["static_feat"], + ) + + result_g1 = ts[0] + result_g2 = ts[1] + + assert torch.is_tensor(result_g1["st"]) + assert result_g1["st"].item() == 10 # Static feature for group 1 + assert result_g2["st"].item() == 20 # Static feature for group 2 + + +def test_getitem_with_weight(sample_data): + """Test __getitem__ with weight parameter. + + Validates that weights are correctly returned in the output and have the + expected length and type.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", weight="weight" + ) + + result = ts[0] + assert "weights" in result + assert torch.is_tensor(result["weights"]) + assert len(result["weights"]) == 10 + + +def test_with_future_data(sample_data, future_data): + """Test with future data provided. + + Verifies that future time steps are appended to the end of each group, + especially for known features.""" + ts = TimeSeries( + data=sample_data, + data_future=future_data, + time="timestamp", + target="target_value", + group=["group_id"], + known=["feature1"], + ) + + result_g1 = ts[0] # Group 1 + + assert len(result_g1["t"]) == 8 # 5 original + 3 future for group 1 + + feature1_idx = ts.feature_cols.index("feature1") + assert not torch.isnan( + result_g1["x"][-1, feature1_idx] + ) # feature1 is not NaN in last row + + +def test_future_data_with_weights(sample_data, future_data): + """Test handling of weights with future data. + + Ensures that weights from future data are combined properly and match the + time indices.""" + ts = TimeSeries( + data=sample_data, + data_future=future_data, + time="timestamp", + target="target_value", + group=["group_id"], + weight="weight", + ) + + result = ts[0] # Group 1 + assert "weights" in result + assert torch.is_tensor(result["weights"]) + assert len(result["weights"]) == len(result["t"]) + + +def test_future_data_missing_columns(sample_data): + """Test handling when future data is missing some columns. 
+ + Verifies the handling of missing feature columns in future data by + checking NaN padding.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + incomplete_future = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + # Missing feature2, feature3 + "group_id": [1, 1, 1, 2, 2], + "weight": np.abs(np.random.randn(5)) + 0.1, + } + ) + + ts = TimeSeries( + data=sample_data, + data_future=incomplete_future, + time="timestamp", + target="target_value", + group=["group_id"], + known=["feature1"], + ) + + result = ts[0] + # Check that missing features are NaN in future timepoints + future_indices = np.where(result["t"] >= np.datetime64("2023-01-11"))[0] + feature2_idx = ts.feature_cols.index("feature2") + feature3_idx = ts.feature_cols.index("feature3") + assert torch.isnan(result["x"][future_indices[0], feature2_idx]) + assert torch.isnan(result["x"][future_indices[0], feature3_idx]) + + +def test_different_future_groups(sample_data): + """Test with future data that has different groups than original data. + + Ensures that groups present only in future data are ignored if not + in the original dataset.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + future_with_new_group = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + "feature2": np.random.randn(5), + "feature3": np.random.randn(5), + "group_id": [1, 1, 3, 3, 3], # Group 3 is new + "weight": np.abs(np.random.randn(5)) + 0.1, + "static_feat": [10, 10, 30, 30, 30], + } + ) + + ts = TimeSeries( + data=sample_data, + data_future=future_with_new_group, + time="timestamp", + target="target_value", + group=["group_id"], + ) + + # Original data has groups 1 and 2, but not 3 + assert len(ts) == 2 + assert 3 not in ts._group_ids + + +def test_multiple_targets(sample_data): + """Test handling of multiple target variables. + + Verifies that multiple target columns are handled and returned + as the correct shape in the output.""" + sample_data["target_value2"] = np.cos(np.arange(10)) + 5 + + ts = TimeSeries( + data=sample_data, time="timestamp", target=["target_value", "target_value2"] + ) + + result = ts[0] + assert result["y"].shape == (10, 2) # Two target variables + + +def test_empty_groups(): + """Test handling of empty groups. + + Confirms that the class handles datasets with a single group and + no empty group errors occur.""" + data = pd.DataFrame( + { + "timestamp": pd.date_range(start="2023-01-01", periods=5, freq="D"), + "target_value": np.random.randn(5), + "group_id": [1, 1, 1, 1, 1], # Only one group + } + ) + + ts = TimeSeries( + data=data, time="timestamp", target="target_value", group=["group_id"] + ) + + assert len(ts) == 1 # Only one group + + +def test_metadata_structure(sample_data): + """Test the structure of metadata. 
+ + Ensures the metadata dictionary includes the expected keys and + correct mappings of feature roles.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + num=["feature1", "feature2", "feature3"], + cat=[], # No categorical features + static=["static_feat"], + known=["feature1"], + unknown=["feature2", "feature3"], + ) + + metadata = ts.get_metadata() + + assert "cols" in metadata + assert "col_type" in metadata + assert "col_known" in metadata + + assert metadata["cols"]["y"] == ["target_value"] + assert set(metadata["cols"]["x"]) == { + "feature1", + "feature2", + "feature3", + "group_id", + "weight", + "static_feat", + } + assert metadata["cols"]["st"] == ["static_feat"] + + assert metadata["col_type"]["feature1"] == "F" + assert metadata["col_type"]["feature2"] == "F" + + assert metadata["col_known"]["feature1"] == "K" + assert metadata["col_known"]["feature2"] == "U" diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py new file mode 100644 index 000000000..c14e3d8f4 --- /dev/null +++ b/tests/test_data/test_data_module.py @@ -0,0 +1,464 @@ +import numpy as np +import pandas as pd +import pytest + +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_timeseries_data(): + """Create a sample time series dataset with only numerical values.""" + num_groups = 5 + seq_length = 100 + + groups = [] + times = [] + values = [] + categorical_feature = [] + continuous_feature1 = [] + continuous_feature2 = [] + known_future = [] + + for g in range(num_groups): + for t in range(seq_length): + groups.append(g) + times.append(pd.Timestamp("2020-01-01") + pd.Timedelta(days=t)) + + value = 10 + 0.1 * t + 5 * np.sin(t / 10) + g * 2 + np.random.normal(0, 1) + values.append(value) + + categorical_feature.append(np.random.choice([0, 1, 2])) + + continuous_feature1.append(np.random.normal(g, 1)) + continuous_feature2.append(value * 0.5 + np.random.normal(0, 0.5)) + + known_future.append(t % 7) + + df = pd.DataFrame( + { + "group": groups, + "time": times, + "target": values, + "cat_feat": categorical_feature, + "cont_feat1": continuous_feature1, + "cont_feat2": continuous_feature2, + "known_future": known_future, + } + ) + + time_series = TimeSeries( + data=df, + time="time", + target="target", + group=["group"], + num=["cont_feat1", "cont_feat2", "known_future"], + cat=["cat_feat"], + known=["known_future"], + ) + + return time_series + + +@pytest.fixture +def data_module(sample_timeseries_data): + """Create a data module instance.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.7, 0.15, 0.15), + ) + return dm + + +def test_init(sample_timeseries_data): + """Test the initialization of the data module. 
+ + Verifies hyperparameter assignment and basic time_series_metadata creation.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=8, + ) + + assert dm.max_encoder_length == 24 + assert dm.max_prediction_length == 12 + assert dm.min_encoder_length == 24 + assert dm.min_prediction_length == 12 + assert dm.batch_size == 8 + assert dm.train_val_test_split == (0.7, 0.15, 0.15) + + assert isinstance(dm.time_series_metadata, dict) + assert "cols" in dm.time_series_metadata + + +def test_prepare_metadata(data_module): + """Test the metadata preparation method. + + Ensures that internal metadata keys are created correctly.""" + metadata = data_module._prepare_metadata() + + assert "encoder_cat" in metadata + assert "encoder_cont" in metadata + assert "decoder_cat" in metadata + assert "decoder_cont" in metadata + assert "target" in metadata + assert "max_encoder_length" in metadata + assert "max_prediction_length" in metadata + + assert metadata["max_encoder_length"] == 24 + assert metadata["max_prediction_length"] == 12 + + +def test_metadata_property(data_module): + """Test the metadata property. + + Confirms caching behavior and correct feature counts.""" + metadata = data_module.metadata + + # Should return the same object when called multiple times (caching) + assert data_module.metadata is metadata + + assert metadata["encoder_cat"] == 1 # cat_feat + assert metadata["encoder_cont"] == 3 # cont_feat1, cont_feat2, known_future + assert metadata["decoder_cat"] == 0 # No categorical features marked as known + assert metadata["decoder_cont"] == 1 # Only known_future marked as known + + +# def test_setup(data_module): +# """Test the setup method that prepares the datasets.""" +# data_module.setup(stage="fit") +# print(data_module._val_indices) +# assert hasattr(data_module, "train_dataset") +# assert hasattr(data_module, "val_dataset") +# assert len(data_module.train_windows) > 0 +# assert len(data_module.val_windows) > 0 +# +# data_module.setup(stage="test") +# assert hasattr(data_module, "test_dataset") +# assert len(data_module.test_windows) > 0 +# +# data_module.setup(stage="predict") +# assert hasattr(data_module, "predict_dataset") +# assert len(data_module.predict_windows) > 0 + + +def test_create_windows(data_module): + """Test the window creation logic. + + Validates window structure and length settings.""" + data_module.setup() + + windows = data_module._create_windows(data_module._train_indices) + + assert len(windows) > 0 + + for window in windows: + assert len(window) == 4 + assert window[2] == data_module.max_encoder_length + assert window[3] == data_module.max_prediction_length + + +def test_dataloader_creation(data_module): + """Test that dataloaders are created correctly. 
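As a back-of-envelope check on the numbers involved: the fixture builds 100-step series, so with a 24-step encoder and a 12-step horizon each series can contribute at most 65 fully contained windows (an upper-bound sketch; the exact count from `_create_windows` may be lower once prediction-index or stride constraints apply):

    seq_length, enc, pred = 100, 24, 12
    n_windows = seq_length - (enc + pred) + 1
    assert n_windows == 65  # upper bound on windows per series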
+ + Checks batch sizes and dataloader instantiation across all stages.""" + data_module.setup() + + train_loader = data_module.train_dataloader() + assert train_loader.batch_size == data_module.batch_size + assert train_loader.num_workers == data_module.num_workers + + val_loader = data_module.val_dataloader() + assert val_loader.batch_size == data_module.batch_size + + data_module.setup(stage="test") + test_loader = data_module.test_dataloader() + assert test_loader.batch_size == data_module.batch_size + + data_module.setup(stage="predict") + predict_loader = data_module.predict_dataloader() + assert predict_loader.batch_size == data_module.batch_size + + +def test_processed_dataset(data_module): + """Test the internal ProcessedEncoderDecoderDataset class. + + Verifies sample structure and tensor dimensions for encoder/decoder inputs.""" + data_module.setup() + + assert len(data_module.train_dataset) == len(data_module.train_windows) + assert len(data_module.val_dataset) == len(data_module.val_windows) + + x, y = data_module.train_dataset[0] + + required_keys = [ + "encoder_cat", + "encoder_cont", + "decoder_cat", + "decoder_cont", + "encoder_lengths", + "decoder_lengths", + "decoder_target_lengths", + "groups", + "encoder_time_idx", + "decoder_time_idx", + "target_scale", + "encoder_mask", + "decoder_mask", + ] + + for key in required_keys: + assert key in x + + assert x["encoder_cat"].shape[0] == data_module.max_encoder_length + assert x["decoder_cat"].shape[0] == data_module.max_prediction_length + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x["decoder_cat"].shape[1] == known_cat_count + assert x["decoder_cont"].shape[1] == known_cont_count + + assert y.shape[0] == data_module.max_prediction_length + + +def test_collate_fn(data_module): + """Test the collate function that combines batch samples. + + Ensures proper stacking of dictionary keys and batch outputs.""" + data_module.setup() + + batch_size = 3 + batch = [data_module.train_dataset[i] for i in range(batch_size)] + + x_batch, y_batch = data_module.collate_fn(batch) + + for key in x_batch: + assert x_batch[key].shape[0] == batch_size + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x_batch["decoder_cat"].shape[2] == known_cat_count + assert x_batch["decoder_cont"].shape[2] == known_cont_count + assert y_batch.shape[0] == batch_size + assert y_batch.shape[1] == data_module.max_prediction_length + + +def test_full_dataloader_iteration(data_module): + """Test a full iteration through the train dataloader. 
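The stacking behaviour that `test_collate_fn` relies on amounts to something like the following (a generic sketch, not the module's actual `collate_fn`):

    import torch

    def collate(batch):
        """Stack a list of (x_dict, y) samples into batch-first tensors."""
        x_batch = {
            key: torch.stack([x[key] for x, _ in batch])  # adds batch dim
            for key in batch[0][0]
        }
        y_batch = torch.stack([y for _, y in batch])
        return x_batch, y_batch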
+ + Confirms batch retrieval and tensor dimensions match configuration.""" + data_module.setup() + train_loader = data_module.train_dataloader() + + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[0] == data_module.batch_size + assert x_batch["encoder_cat"].shape[1] == data_module.max_encoder_length + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x_batch["decoder_cat"].shape[0] == data_module.batch_size + assert x_batch["decoder_cat"].shape[2] == known_cat_count + assert x_batch["decoder_cont"].shape[0] == data_module.batch_size + assert x_batch["decoder_cont"].shape[2] == known_cont_count + assert y_batch.shape[0] == data_module.batch_size + assert y_batch.shape[1] == data_module.max_prediction_length + + +def test_variable_encoder_lengths(sample_timeseries_data): + """Test with variable encoder lengths. + + Ensures random length behavior is respected and functional.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + min_encoder_length=12, + max_prediction_length=12, + batch_size=4, + randomize_length=True, + ) + + dm.setup() + assert dm.min_encoder_length == 12 + assert dm.max_encoder_length == 24 + + +def test_preprocess_data(data_module, sample_timeseries_data): + """Test the _preprocess_data method. + + Checks preprocessing output structure and alignment with raw data.""" + if not hasattr(data_module, "_split_indices"): + data_module.setup() + + series_idx = data_module._train_indices[0] + + processed = data_module._preprocess_data(series_idx) + + assert "features" in processed + assert "categorical" in processed["features"] + assert "continuous" in processed["features"] + assert "target" in processed + assert "time_mask" in processed + + original_sample = sample_timeseries_data[series_idx.item()] + expected_length = len(original_sample["y"]) + + assert processed["features"]["categorical"].shape[0] == expected_length + assert processed["features"]["continuous"].shape[0] == expected_length + assert processed["target"].shape[0] == expected_length + + +def test_with_static_features(): + """Test with static features included. 
+ + Validates static feature support in both metadata and sample input.""" + df = pd.DataFrame( + { + "group": [0, 0, 0, 1, 1, 1], + "time": pd.date_range("2020-01-01", periods=6), + "target": [1, 2, 3, 4, 5, 6], + "static_cat": [0, 0, 0, 1, 1, 1], + "static_num": [10, 10, 10, 20, 20, 20], + "feature1": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + } + ) + + ts = TimeSeries( + data=df, + time="time", + target="target", + group=["group"], + num=["feature1", "static_num"], + static=["static_cat", "static_num"], + cat=["static_cat"], + ) + + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=ts, + max_encoder_length=2, + max_prediction_length=1, + batch_size=2, + ) + + dm.setup() + + metadata = dm.metadata + assert metadata["static_categorical_features"] == 1 + assert metadata["static_continuous_features"] == 1 + + x, y = dm.train_dataset[0] + assert "static_categorical_features" in x + assert "static_continuous_features" in x + + +# def test_different_train_val_test_split(sample_timeseries_data): +# """Test with different train/val/test split ratios.""" +# dm = EncoderDecoderTimeSeriesDataModule( +# time_series_dataset=sample_timeseries_data, +# max_encoder_length=24, +# max_prediction_length=12, +# batch_size=4, +# train_val_test_split=(0.8, 0.1, 0.1), +# ) +# +# dm.setup() +# +# total_series = len(sample_timeseries_data) +# expected_train = int(0.8 * total_series) +# expected_val = int(0.1 * total_series) +# +# assert len(dm._train_indices) == expected_train +# assert len(dm._val_indices) == expected_val +# assert len(dm._test_indices) == total_series - expected_train - expected_val + + +def test_multivariate_target(): + """Test with multivariate target (multiple target columns). + + Verifies correct handling of multivariate targets in data pipeline.""" + df = pd.DataFrame( + { + "group": np.repeat([0, 1], 50), + "time": np.tile(pd.date_range("2020-01-01", periods=50), 2), + "target1": np.random.normal(0, 1, 100), + "target2": np.random.normal(5, 2, 100), + "feature1": np.random.normal(0, 1, 100), + "feature2": np.random.normal(0, 1, 100), + } + ) + + ts = TimeSeries( + data=df, + time="time", + target=["target1", "target2"], + group=["group"], + num=["feature1", "feature2"], + ) + + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=ts, + max_encoder_length=10, + max_prediction_length=5, + batch_size=4, + ) + + dm.setup() + + x, y = dm.train_dataset[0] + assert y.shape[-1] == 2 From cdecb770a63269c965261cee3a54744449b445a4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 19 Apr 2025 19:44:16 +0530 Subject: [PATCH 10/33] Code quality --- pytorch_forecasting/data/timeseries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 6b2662e95..fda08d561 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -9,7 +9,7 @@ from copy import copy as _copy, deepcopy from functools import lru_cache import inspect -from typing import Any, Callable, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union import warnings import numpy as np From 20aafb749cfebdb1f9789b4dff5120fa8527db74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 30 Apr 2025 18:40:01 +0200 Subject: [PATCH 11/33] refactor file --- pytorch_forecasting/data/__init__.py | 3 +- .../data/timeseries/__init__.py | 9 + .../data/timeseries/_coerce.py | 25 ++ .../_timeseries.py} | 286 +----------------- 
.../data/timeseries/_timeseries_v2.py | 276 +++++++++++++++++ 5 files changed, 314 insertions(+), 285 deletions(-) create mode 100644 pytorch_forecasting/data/timeseries/__init__.py create mode 100644 pytorch_forecasting/data/timeseries/_coerce.py rename pytorch_forecasting/data/{timeseries.py => timeseries/_timeseries.py} (90%) create mode 100644 pytorch_forecasting/data/timeseries/_timeseries_v2.py diff --git a/pytorch_forecasting/data/__init__.py b/pytorch_forecasting/data/__init__.py index 301c8394d..17be285a0 100644 --- a/pytorch_forecasting/data/__init__.py +++ b/pytorch_forecasting/data/__init__.py @@ -13,10 +13,11 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler -from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries import TimeSeries, TimeSeriesDataSet __all__ = [ "TimeSeriesDataSet", + "TimeSeries", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py new file mode 100644 index 000000000..7734cccf2 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -0,0 +1,9 @@ +"""Data loaders for time series data.""" + +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries +from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet + +__all__ = [ + "TimeSeriesDataSet", + "TimeSeries", +] diff --git a/pytorch_forecasting/data/timeseries/_coerce.py b/pytorch_forecasting/data/timeseries/_coerce.py new file mode 100644 index 000000000..328431aa8 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_coerce.py @@ -0,0 +1,25 @@ +"""Coercion functions for various data types.""" + +from copy import deepcopy + + +def _coerce_to_list(obj): + """Coerce object to list. + + None is coerced to empty list, otherwise list constructor is used. + """ + if obj is None: + return [] + if isinstance(obj, str): + return [obj] + return list(obj) + + +def _coerce_to_dict(obj): + """Coerce object to dict. + + None is coerced to empty dict, otherwise deepcopy is used. + """ + if obj is None: + return {} + return deepcopy(obj) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py similarity index 90% rename from pytorch_forecasting/data/timeseries.py rename to pytorch_forecasting/data/timeseries/_timeseries.py index fda08d561..263e0ea3a 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -9,7 +9,7 @@ from copy import copy as _copy, deepcopy from functools import lru_cache import inspect -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Optional, Type, TypeVar, Union import warnings import numpy as np @@ -31,6 +31,7 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler +from pytorch_forecasting.data.timeseries._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils import repr_class from pytorch_forecasting.utils._dependencies import _check_matplotlib @@ -2663,286 +2664,3 @@ def __repr__(self) -> str: attributes=self.get_parameters(), extra_attributes=dict(length=len(self)), ) - - -def _coerce_to_list(obj): - """Coerce object to list. - - None is coerced to empty list, otherwise list constructor is used.
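For reference, the intended behaviour of the two relocated helpers (an interactive sketch):

    >>> _coerce_to_list(None)
    []
    >>> _coerce_to_list("a")
    ['a']
    >>> _coerce_to_list(("a", "b"))
    ['a', 'b']
    >>> _coerce_to_dict(None)
    {}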
- """ - if obj is None: - return [] - if isinstance(obj, str): - return [obj] - return list(obj) - - -def _coerce_to_dict(obj): - """Coerce object to dict. - - None is coerce to empty dict, otherwise deepcopy is used. - """ - if obj is None: - return {} - return deepcopy(obj) - - -####################################################################################### -# Disclaimer: This dataset class is still work in progress and experimental, please -# use with care. This class is a basic skeleton of how the data-handling pipeline may -# look like in the future. -# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion -# and turning the data to tensors. -# For now, this pipeline handles the simplest situation: The whole data can be loaded -# into the memory. -####################################################################################### - - -class TimeSeries(Dataset): - """PyTorch Dataset for time series data stored in pandas DataFrame. - - Parameters - ---------- - data : pd.DataFrame - data frame with sequence data. - Column names must all be str, and contain str as referred to below. - data_future : pd.DataFrame, optional, default=None - data frame with future data. - Column names must all be str, and contain str as referred to below. - May contain only columns that are in time, group, weight, known, or static. - time : str, optional, default = first col not in group_ids, weight, target, static. - integer typed column denoting the time index within ``data``. - This column is used to determine the sequence of samples. - If there are no missing observations, - the time index should increase by ``+1`` for each subsequent sample. - The first time_idx for each series does not necessarily - have to be ``0`` but any value is allowed. - target : str or List[str], optional, default = last column (at iloc -1) - column(s) in ``data`` denoting the forecasting target. - Can be categorical or numerical dtype. - group : List[str], optional, default = None - list of column names identifying a time series instance within ``data``. - This means that the ``group`` together uniquely identify an instance, - and ``group`` together with ``time`` uniquely identify a single observation - within a time series instance. - If ``None``, the dataset is assumed to be a single time series. - weight : str, optional, default=None - column name for weights. - If ``None``, it is assumed that there is no weight column. - num : list of str, optional, default = all columns with dtype in "fi" - list of numerical variables in ``data``, - list may also contain list of str, which are then grouped together. - cat : list of str, optional, default = all columns with dtype in "Obc" - list of categorical variables in ``data``, - list may also contain list of str, which are then grouped together - (e.g. useful for product categories). - known : list of str, optional, default = all variables - list of variables that change over time and are known in the future, - list may also contain list of str, which are then grouped together - (e.g. useful for special days or promotion categories). - unknown : list of str, optional, default = no variables - list of variables that are not known in the future, - list may also contain list of str, which are then grouped together - (e.g. useful for weather categories). - static : list of str, optional, default = all variables not in known, unknown - list of variables that do not change over time, - list may also contain list of str, which are then grouped together. 
- """ - - def __init__( - self, - data: pd.DataFrame, - data_future: Optional[pd.DataFrame] = None, - time: Optional[str] = None, - target: Optional[Union[str, List[str]]] = None, - group: Optional[List[str]] = None, - weight: Optional[str] = None, - num: Optional[List[Union[str, List[str]]]] = None, - cat: Optional[List[Union[str, List[str]]]] = None, - known: Optional[List[Union[str, List[str]]]] = None, - unknown: Optional[List[Union[str, List[str]]]] = None, - static: Optional[List[Union[str, List[str]]]] = None, - ): - - self.data = data - self.data_future = data_future - self.time = time - self.target = _coerce_to_list(target) - self.group = _coerce_to_list(group) - self.weight = weight - self.num = _coerce_to_list(num) - self.cat = _coerce_to_list(cat) - self.known = _coerce_to_list(known) - self.unknown = _coerce_to_list(unknown) - self.static = _coerce_to_list(static) - - self.feature_cols = [ - col - for col in data.columns - if col not in [self.time] + self.group + [self.weight] + self.target - ] - if self.group: - self._groups = self.data.groupby(self.group).groups - self._group_ids = list(self._groups.keys()) - else: - self._groups = {"_single_group": self.data.index} - self._group_ids = ["_single_group"] - - self._prepare_metadata() - - def _prepare_metadata(self): - """Prepare metadata for the dataset. - - The funcion returns metadata that contains: - - * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } - Names of columns for y, x, and static features. - List elements are in same order as column dimensions. - Columns not appearing are assumed to be named (x0, x1, etc.), - (y0, y1, etc.), (st0, st1, etc.). - * ``col_type``: dict[str, str] - maps column names to data types "F" (numerical) and "C" (categorical). - Column names not occurring are assumed "F". - * ``col_known``: dict[str, str] - maps column names to "K" (future known) or "U" (future unknown). - Column names not occurring are assumed "K". - """ - self.metadata = { - "cols": { - "y": self.target, - "x": self.feature_cols, - "st": self.static, - }, - "col_type": {}, - "col_known": {}, - } - - all_cols = self.target + self.feature_cols + self.static - for col in all_cols: - self.metadata["col_type"][col] = "C" if col in self.cat else "F" - - self.metadata["col_known"][col] = "K" if col in self.known else "U" - - def __len__(self) -> int: - """Return number of time series in the dataset.""" - return len(self._group_ids) - - def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: - """Get time series data for given index. - - Returns - ------- - t : numpy.ndarray of shape (n_timepoints,) - Time index for each time point in the past or present. Aligned with `y`, - and `x` not ending in `f`. - - y : torch.Tensor of shape (n_timepoints, n_targets) - Target values for each time point. Rows are time points, aligned with `t`. - - x : torch.Tensor of shape (n_timepoints, n_features) - Features for each time point. Rows are time points, aligned with `t`. - - group : torch.Tensor of shape (n_groups,) - Group identifiers for time series instances. - - st : torch.Tensor of shape (n_static_features,) - Static features. - - cutoff_time : float or numpy.float64 - Cutoff time for the time series instance. - - Other Returns - ------------- - weights : torch.Tensor of shape (n_timepoints,), optional - Only included if weights are not `None`. 
- """ - group_id = self._group_ids[index] - - if self.group: - mask = self._groups[group_id] - data = self.data.loc[mask] - else: - data = self.data - - cutoff_time = data[self.time].max() - - result = { - "t": data[self.time].values, - "y": torch.tensor(data[self.target].values), - "x": torch.tensor(data[self.feature_cols].values), - "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), - "cutoff_time": cutoff_time, - } - - if self.data_future is not None: - if self.group: - future_mask = self.data_future.groupby(self.group).groups[group_id] - future_data = self.data_future.loc[future_mask] - else: - future_data = self.data_future - - combined_times = np.concatenate( - [data[self.time].values, future_data[self.time].values] - ) - combined_times = np.unique(combined_times) - combined_times.sort() - - num_timepoints = len(combined_times) - x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self.target)), np.nan) - - current_time_indices = {t: i for i, t in enumerate(combined_times)} - for i, t in enumerate(data[self.time].values): - idx = current_time_indices[t] - x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self.target].values[i] - - for i, t in enumerate(future_data[self.time].values): - if t in current_time_indices: - idx = current_time_indices[t] - for j, col in enumerate(self.known): - if col in self.feature_cols: - feature_idx = self.feature_cols.index(col) - x_merged[idx, feature_idx] = future_data[col].values[i] - - result.update( - { - "t": combined_times, - "x": torch.tensor(x_merged, dtype=torch.float32), - "y": torch.tensor(y_merged, dtype=torch.float32), - } - ) - - if self.weight: - if self.data_future is not None and self.weight in self.data_future.columns: - weights_merged = np.full(num_timepoints, np.nan) - for i, t in enumerate(data[self.time].values): - idx = current_time_indices[t] - weights_merged[idx] = data[self.weight].values[i] - - for i, t in enumerate(future_data[self.time].values): - if t in current_time_indices and self.weight in future_data.columns: - idx = current_time_indices[t] - weights_merged[idx] = future_data[self.weight].values[i] - - result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) - else: - result["weights"] = torch.tensor( - data[self.weight].values, dtype=torch.float32 - ) - - return result - - def get_metadata(self) -> Dict: - """Return metadata about the dataset. - - Returns - ------- - Dict - Dictionary containing: - - cols: column names for y, x, and static features - - col_type: mapping of columns to their types (F/C) - - col_known: mapping of columns to their future known status (K/U) - """ - return self.metadata diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py new file mode 100644 index 000000000..53bf7228d --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -0,0 +1,276 @@ +""" +Timeseries dataset - v2 prototype. + +Beta version, experimental - use for testing but not in production. 
+""" + +from typing import Dict, List, Optional, Union +import warnings + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list + + +####################################################################################### +# Disclaimer: This dataset class is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the data-handling pipeline may +# look like in the future. +# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion +# and turning the data to tensors. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. + weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. 
+ """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = _coerce_to_list(target) + self.group = _coerce_to_list(group) + self.weight = weight + self.num = _coerce_to_list(num) + self.cat = _coerce_to_list(cat) + self.known = _coerce_to_list(known) + self.unknown = _coerce_to_list(unknown) + self.static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self.group + [self.weight] + self.target + ] + if self.group: + self._groups = self.data.groupby(self.group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + def _prepare_metadata(self): + """Prepare metadata for the dataset. + + The funcion returns metadata that contains: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. 
+ """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if self.weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[self.weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. 
+ + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata From 043820dd3be3041a019fd9cd2cb1e681d25a79a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 30 Apr 2025 18:43:50 +0200 Subject: [PATCH 12/33] warning --- .../data/timeseries/_timeseries_v2.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 53bf7228d..1c91d2525 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -104,6 +104,18 @@ def __init__( self.unknown = _coerce_to_list(unknown) self.static = _coerce_to_list(static) + warnings.warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + self.feature_cols = [ col for col in data.columns From 1720a15e9cff3e5c3ebcd0bf3ec03995d068e4b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 13:58:09 +0200 Subject: [PATCH 13/33] linting --- pytorch_forecasting/data/timeseries/__init__.py | 2 +- pytorch_forecasting/data/timeseries/_timeseries_v2.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py index 7734cccf2..85973267a 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,7 +1,7 @@ """Data loaders for time series data.""" -from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries __all__ = [ "TimeSeriesDataSet", diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 1c91d2525..76972ab4d 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -14,7 +14,6 @@ from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list - ####################################################################################### # Disclaimer: This dataset class is still work in progress and experimental, please # use with care. 
From 1720a15e9cff3e5c3ebcd0bf3ec03995d068e4b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 13:58:09 +0200
Subject: [PATCH 13/33] linting

---
 pytorch_forecasting/data/timeseries/__init__.py       | 2 +-
 pytorch_forecasting/data/timeseries/_timeseries_v2.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py
index 7734cccf2..85973267a 100644
--- a/pytorch_forecasting/data/timeseries/__init__.py
+++ b/pytorch_forecasting/data/timeseries/__init__.py
@@ -1,7 +1,7 @@
 """Data loaders for time series data."""
 
-from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries
 from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet
+from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries
 
 __all__ = [
     "TimeSeriesDataSet",

diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 1c91d2525..76972ab4d 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -14,7 +14,6 @@
 
 from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list
 
-
 #######################################################################################
 # Disclaimer: This dataset class is still work in progress and experimental, please
 # use with care. This class is a basic skeleton of how the data-handling pipeline may

From af44474d16b3fcdf5e99acb4b9d1f7345119d8cd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:21:58 +0200
Subject: [PATCH 14/33] move coercion to utils

---
 pytorch_forecasting/data/data_module.py                   | 6 ++----
 pytorch_forecasting/{data/timeseries => utils}/_coerce.py | 0
 2 files changed, 2 insertions(+), 4 deletions(-)
 rename pytorch_forecasting/{data/timeseries => utils}/_coerce.py (100%)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 1203e83ac..9d3ebbedb 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -19,10 +19,8 @@
     NaNLabelEncoder,
     TorchNormalizer,
 )
-from pytorch_forecasting.data.timeseries import (
-    TimeSeries,
-    _coerce_to_dict,
-)
+from pytorch_forecasting.data.timeseries import TimeSeries
+from pytorch_forecasting.utils._coerce import _coerce_to_dict
 
 NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]
 
diff --git a/pytorch_forecasting/data/timeseries/_coerce.py b/pytorch_forecasting/utils/_coerce.py
similarity index 100%
rename from pytorch_forecasting/data/timeseries/_coerce.py
rename to pytorch_forecasting/utils/_coerce.py
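The moved helpers are pure input-normalization utilities; the rename keeps their bodies unchanged, so they do not appear in this patch. Their behavior is presumably along these lines (an illustrative sketch, not the actual source):

def _coerce_to_list(obj):
    """Coerce object to list; None becomes the empty list."""
    if obj is None:
        return []
    if isinstance(obj, str):
        return [obj]
    return list(obj)


def _coerce_to_dict(obj):
    """Coerce object to dict; None becomes the empty dict."""
    if obj is None:
        return {}
    return dict(obj)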
From a3cb8b736b0b134c8faa97f5ef2993deb28fb75b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:22:18 +0200
Subject: [PATCH 15/33] linting

---
 pytorch_forecasting/data/timeseries/_timeseries.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py
index 263e0ea3a..30fe9e0bb 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries.py
@@ -31,8 +31,8 @@
     TorchNormalizer,
 )
 from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler
-from pytorch_forecasting.data.timeseries._coerce import _coerce_to_dict, _coerce_to_list
 from pytorch_forecasting.utils import repr_class
+from pytorch_forecasting.utils._coerce import _coerce_to_dict, _coerce_to_list
 from pytorch_forecasting.utils._dependencies import _check_matplotlib

From 75d7fb54d8405ef493197c5a4d2fc86a5e9e9d5b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:25:51 +0200
Subject: [PATCH 16/33] Update _timeseries_v2.py

---
 pytorch_forecasting/data/timeseries/_timeseries_v2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 76972ab4d..afa45725b 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -12,7 +12,7 @@
 import torch
 from torch.utils.data import Dataset
 
-from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list
+from pytorch_forecasting.utils._coerce import _coerce_to_list
 
 #######################################################################################
 # Disclaimer: This dataset class is still work in progress and experimental, please

From 1b946e699be9db2e201a2361779a695356a0460b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:30:13 +0200
Subject: [PATCH 17/33] Update __init__.py

---
 pytorch_forecasting/data/timeseries/__init__.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py
index 85973267a..b359a0aa9 100644
--- a/pytorch_forecasting/data/timeseries/__init__.py
+++ b/pytorch_forecasting/data/timeseries/__init__.py
@@ -1,9 +1,15 @@
 """Data loaders for time series data."""
 
-from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet
+from pytorch_forecasting.data.timeseries._timeseries import (
+    _find_end_indices,
+    check_for_nonfinite,
+    TimeSeriesDataSet,
+)
 from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries
 
 __all__ = [
+    "_find_end_indices",
+    "check_for_nonfinite",
     "TimeSeriesDataSet",
     "TimeSeries",
 ]

From 3edb08b7ea1b97d06b47b0ebcc83aaef9bec8083 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:33:17 +0200
Subject: [PATCH 18/33] Update __init__.py

---
 pytorch_forecasting/data/timeseries/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py
index b359a0aa9..788c08201 100644
--- a/pytorch_forecasting/data/timeseries/__init__.py
+++ b/pytorch_forecasting/data/timeseries/__init__.py
@@ -1,9 +1,9 @@
 """Data loaders for time series data."""
 
 from pytorch_forecasting.data.timeseries._timeseries import (
+    TimeSeriesDataSet,
     _find_end_indices,
     check_for_nonfinite,
-    TimeSeriesDataSet,
 )
 from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries

From e350291c110f567e69946e0e113f2471b7472738 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Sun, 11 May 2025 22:10:01 +0530
Subject: [PATCH 19/33] update tests

---
 tests/test_data/test_data_module.py | 72 ++++++++++++++---------------
 1 file changed, 36 insertions(+), 36 deletions(-)

diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py
index c14e3d8f4..4051b852c 100644
--- a/tests/test_data/test_data_module.py
+++ b/tests/test_data/test_data_module.py
@@ -9,7 +9,7 @@
 @pytest.fixture
 def sample_timeseries_data():
     """Create a sample time series dataset with only numerical values."""
-    num_groups = 5
+    num_groups = 10
     seq_length = 100
 
     groups = []
@@ -128,22 +128,22 @@ def test_metadata_property(data_module):
     assert metadata["decoder_cont"] == 1  # Only known_future marked as known
 
 
-# def test_setup(data_module):
-#     """Test the setup method that prepares the datasets."""
-#     data_module.setup(stage="fit")
-#     print(data_module._val_indices)
-#     assert hasattr(data_module, "train_dataset")
-#     assert hasattr(data_module, "val_dataset")
-#     assert len(data_module.train_windows) > 0
-#     assert len(data_module.val_windows) > 0
-#
-#     data_module.setup(stage="test")
-#     assert hasattr(data_module, "test_dataset")
-#     assert len(data_module.test_windows) > 0
-#
-#     data_module.setup(stage="predict")
-#     assert hasattr(data_module, "predict_dataset")
-#     assert len(data_module.predict_windows) > 0
+def test_setup(data_module):
+    """Test the setup method that prepares the datasets."""
+    data_module.setup(stage="fit")
+    print(data_module._val_indices)
+    assert hasattr(data_module, "train_dataset")
+    assert hasattr(data_module, "val_dataset")
+    assert len(data_module.train_windows) > 0
+    assert len(data_module.val_windows) > 0
+
+    data_module.setup(stage="test")
+    assert hasattr(data_module, "test_dataset")
+    assert len(data_module.test_windows) > 0
+
+    data_module.setup(stage="predict")
+    assert hasattr(data_module, "predict_dataset")
+    assert len(data_module.predict_windows) > 0
 
 
 def test_create_windows(data_module):
@@ -407,25 +407,25 @@ def test_with_static_features():
     assert "static_continuous_features" in x
 
 
-# def test_different_train_val_test_split(sample_timeseries_data):
-#     """Test with different train/val/test split ratios."""
-#     dm = EncoderDecoderTimeSeriesDataModule(
-#         time_series_dataset=sample_timeseries_data,
-#         max_encoder_length=24,
-#         max_prediction_length=12,
-#         batch_size=4,
-#         train_val_test_split=(0.8, 0.1, 0.1),
-#     )
-#
-#     dm.setup()
-#
-#     total_series = len(sample_timeseries_data)
-#     expected_train = int(0.8 * total_series)
-#     expected_val = int(0.1 * total_series)
-#
-#     assert len(dm._train_indices) == expected_train
-#     assert len(dm._val_indices) == expected_val
-#     assert len(dm._test_indices) == total_series - expected_train - expected_val
+def test_different_train_val_test_split(sample_timeseries_data):
+    """Test with different train/val/test split ratios."""
+    dm = EncoderDecoderTimeSeriesDataModule(
+        time_series_dataset=sample_timeseries_data,
+        max_encoder_length=24,
+        max_prediction_length=12,
+        batch_size=4,
+        train_val_test_split=(0.8, 0.1, 0.1),
+    )
+
+    dm.setup()
+
+    total_series = len(sample_timeseries_data)
+    expected_train = int(0.8 * total_series)
+    expected_val = int(0.1 * total_series)
+
+    assert len(dm._train_indices) == expected_train
+    assert len(dm._val_indices) == expected_val
+    assert len(dm._test_indices) == total_series - expected_train - expected_val
 
 
 def test_multivariate_target():
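The fixture bump from 5 to 10 groups earlier in this patch is presumably what makes this re-enabled test viable: split sizes are computed with integer truncation, so small series counts can leave a split empty. Spelled out with the test's own numbers (illustrative arithmetic):

total_series = 10
expected_train = int(0.8 * total_series)  # 8
expected_val = int(0.1 * total_series)    # 1
expected_test = total_series - expected_train - expected_val  # 1

# With the old fixture of 5 series, int(0.1 * 5) == 0 would have left the
# validation split empty and the dataloader-based assertions vacuous.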
self.metadata.get("static_continuous_features", 0) + self.static_input_dim = self.static_cat_dim + self.static_cont_dim + + if self.encoder_input_dim > 0: + self.encoder_var_selection = nn.Sequential( + nn.Linear(self.encoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.encoder_input_dim), + nn.Sigmoid(), + ) + else: + self.encoder_var_selection = None + + if self.decoder_input_dim > 0: + self.decoder_var_selection = nn.Sequential( + nn.Linear(self.decoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.decoder_input_dim), + nn.Sigmoid(), + ) + else: + self.decoder_var_selection = None - self.static_context_linear = ( - nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None - ) + if self.static_input_dim > 0: + self.static_context_linear = nn.Linear(self.static_input_dim, hidden_size) + else: + self.static_context_linear = None + _lstm_encoder_input_actual_dim = self.encoder_input_dim self.lstm_encoder = nn.LSTM( - input_size=total_feature_size, + input_size=max(1, _lstm_encoder_input_actual_dim), hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True, ) + _lstm_decoder_input_actual_dim = self.decoder_input_dim self.lstm_decoder = nn.LSTM( - input_size=total_feature_size, + input_size=max(1, _lstm_decoder_input_actual_dim), hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, @@ -97,7 +108,7 @@ def __init__( ) self.pre_output = nn.Linear(hidden_size, hidden_size) - self.output_layer = nn.Linear(hidden_size, output_size) + self.output_layer = nn.Linear(hidden_size, self.output_size) def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ From 77cb979808d83cbcfb4e7c3ed5ffd888c0828d31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:14:03 +0200 Subject: [PATCH 21/33] warnings and init attr handling --- pytorch_forecasting/data/data_module.py | 44 +++++++++---- .../data/timeseries/_timeseries_v2.py | 61 +++++++++++-------- 2 files changed, 67 insertions(+), 38 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 9d3ebbedb..690fb6057 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -8,6 +8,7 @@ ####################################################################################### from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn from lightning.pytorch import LightningDataModule from sklearn.preprocessing import RobustScaler, StandardScaler @@ -107,33 +108,50 @@ def __init__( num_workers: int = 0, train_val_test_split: tuple = (0.7, 0.15, 0.15), ): - super().__init__() - self.time_series_dataset = time_series_dataset - self.time_series_metadata = time_series_dataset.get_metadata() + self.time_series_dataset = time_series_dataset self.max_encoder_length = max_encoder_length - self.min_encoder_length = min_encoder_length or max_encoder_length + self.min_encoder_length = min_encoder_length self.max_prediction_length = max_prediction_length - self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_length = min_prediction_length self.min_prediction_idx = min_prediction_idx - self.allow_missing_timesteps = allow_missing_timesteps self.add_relative_time_idx = add_relative_time_idx self.add_target_scales = add_target_scales self.add_encoder_length = add_encoder_length self.randomize_length = randomize_length - + self.target_normalizer = 
From 77cb979808d83cbcfb4e7c3ed5ffd888c0828d31 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Tue, 13 May 2025 08:14:03 +0200
Subject: [PATCH 21/33] warnings and init attr handling

---
 pytorch_forecasting/data/data_module.py     | 44 +++++++++----
 .../data/timeseries/_timeseries_v2.py       | 61 +++++++++++--------
 2 files changed, 67 insertions(+), 38 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 9d3ebbedb..690fb6057 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -8,6 +8,7 @@
 #######################################################################################
 
 from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
 
 from lightning.pytorch import LightningDataModule
 from sklearn.preprocessing import RobustScaler, StandardScaler
@@ -107,33 +108,50 @@ def __init__(
         num_workers: int = 0,
         train_val_test_split: tuple = (0.7, 0.15, 0.15),
     ):
-        super().__init__()
-        self.time_series_dataset = time_series_dataset
-        self.time_series_metadata = time_series_dataset.get_metadata()
-
+        self.time_series_dataset = time_series_dataset
         self.max_encoder_length = max_encoder_length
-        self.min_encoder_length = min_encoder_length or max_encoder_length
+        self.min_encoder_length = min_encoder_length
         self.max_prediction_length = max_prediction_length
-        self.min_prediction_length = min_prediction_length or max_prediction_length
+        self.min_prediction_length = min_prediction_length
         self.min_prediction_idx = min_prediction_idx
-
         self.allow_missing_timesteps = allow_missing_timesteps
         self.add_relative_time_idx = add_relative_time_idx
         self.add_target_scales = add_target_scales
         self.add_encoder_length = add_encoder_length
         self.randomize_length = randomize_length
-
+        self.target_normalizer = target_normalizer
+        self.categorical_encoders = categorical_encoders
+        self.scalers = scalers
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.train_val_test_split = train_val_test_split
 
+        warn(
+            "TimeSeries is part of an experimental rework of the "
+            "pytorch-forecasting data layer, "
+            "scheduled for release with v2.0.0. "
+            "The API is not stable and may change without prior warning. "
+            "For beta testing, but not for stable production use. "
+            "Feedback and suggestions are very welcome in "
+            "pytorch-forecasting issue 1736, "
+            "https://github.com/sktime/pytorch-forecasting/issues/1736",
+            UserWarning,
+        )
+
+        super().__init__()
+
+        # handle defaults and derived attributes
         if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto":
-            self.target_normalizer = RobustScaler()
+            self._target_normalizer = RobustScaler()
         else:
-            self.target_normalizer = target_normalizer
+            self._target_normalizer = target_normalizer
 
-        self.categorical_encoders = _coerce_to_dict(categorical_encoders)
-        self.scalers = _coerce_to_dict(scalers)
+        self.time_series_metadata = time_series_dataset.get_metadata()
+        self._min_prediction_length = min_prediction_length or max_prediction_length
+        self._min_encoder_length = min_encoder_length or max_encoder_length
+        self._categorical_encoders = _coerce_to_dict(categorical_encoders)
+        self._scalers = _coerce_to_dict(scalers)
 
         self.categorical_indices = []
         self.continuous_indices = []
@@ -237,8 +255,8 @@ def _prepare_metadata(self):
             {
                 "max_encoder_length": self.max_encoder_length,
                 "max_prediction_length": self.max_prediction_length,
-                "min_encoder_length": self.min_encoder_length,
-                "min_prediction_length": self.min_prediction_length,
+                "min_encoder_length": self._min_encoder_length,
+                "min_prediction_length": self._min_prediction_length,
             }
         )
 
diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index afa45725b..1f0ba6820 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -5,7 +5,7 @@
 """
 
 from typing import Dict, List, Optional, Union
-import warnings
+from warnings import warn
 
 import numpy as np
 import pandas as pd
@@ -94,16 +94,16 @@ def __init__(
         self.data = data
         self.data_future = data_future
         self.time = time
-        self.target = _coerce_to_list(target)
-        self.group = _coerce_to_list(group)
+        self.target = target
+        self.group = group
         self.weight = weight
-        self.num = _coerce_to_list(num)
-        self.cat = _coerce_to_list(cat)
-        self.known = _coerce_to_list(known)
-        self.unknown = _coerce_to_list(unknown)
-        self.static = _coerce_to_list(static)
+        self.num = num
+        self.cat = cat
+        self.known = known
+        self.unknown = unknown
+        self.static = static
 
-        warnings.warn(
+        warn(
             "TimeSeries is part of an experimental rework of the "
             "pytorch-forecasting data layer, "
             "scheduled for release with v2.0.0. "
" @@ -115,13 +115,24 @@ def __init__( UserWarning, ) + super.__init__() + + # handle defaults, coercion, and derived attributes + self._target = _coerce_to_list(target) + self._group = _coerce_to_list(group) + self._num = _coerce_to_list(num) + self._cat = _coerce_to_list(cat) + self._known = _coerce_to_list(known) + self._unknown = _coerce_to_list(unknown) + self._static = _coerce_to_list(static) + self.feature_cols = [ col for col in data.columns - if col not in [self.time] + self.group + [self.weight] + self.target + if col not in [self.time] + self._group + [self.weight] + self._target ] - if self.group: - self._groups = self.data.groupby(self.group).groups + if self._group: + self._groups = self.data.groupby(self._group).groups self._group_ids = list(self._groups.keys()) else: self._groups = {"_single_group": self.data.index} @@ -148,19 +159,19 @@ def _prepare_metadata(self): """ self.metadata = { "cols": { - "y": self.target, + "y": self._target, "x": self.feature_cols, - "st": self.static, + "st": self._static, }, "col_type": {}, "col_known": {}, } - all_cols = self.target + self.feature_cols + self.static + all_cols = self._target + self.feature_cols + self._static for col in all_cols: - self.metadata["col_type"][col] = "C" if col in self.cat else "F" + self.metadata["col_type"][col] = "C" if col in self._cat else "F" - self.metadata["col_known"][col] = "K" if col in self.known else "U" + self.metadata["col_known"][col] = "K" if col in self._known else "U" def __len__(self) -> int: """Return number of time series in the dataset.""" @@ -197,7 +208,7 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: """ group_id = self._group_ids[index] - if self.group: + if self._group: mask = self._groups[group_id] data = self.data.loc[mask] else: @@ -207,16 +218,16 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: result = { "t": data[self.time].values, - "y": torch.tensor(data[self.target].values), + "y": torch.tensor(data[self._target].values), "x": torch.tensor(data[self.feature_cols].values), "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "st": torch.tensor(data[self._static].iloc[0].values if self._static else []), "cutoff_time": cutoff_time, } if self.data_future is not None: - if self.group: - future_mask = self.data_future.groupby(self.group).groups[group_id] + if self._group: + future_mask = self.data_future.groupby(self._group).groups[group_id] future_data = self.data_future.loc[future_mask] else: future_data = self.data_future @@ -229,18 +240,18 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: num_timepoints = len(combined_times) x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self.target)), np.nan) + y_merged = np.full((num_timepoints, len(self._target)), np.nan) current_time_indices = {t: i for i, t in enumerate(combined_times)} for i, t in enumerate(data[self.time].values): idx = current_time_indices[t] x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self.target].values[i] + y_merged[idx] = data[self._target].values[i] for i, t in enumerate(future_data[self.time].values): if t in current_time_indices: idx = current_time_indices[t] - for j, col in enumerate(self.known): + for j, col in enumerate(self._known): if col in self.feature_cols: feature_idx = self.feature_cols.index(col) x_merged[idx, feature_idx] = future_data[col].values[i] From 
From f8c94e626010d165cf022e0fd3f0a22c994759c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Tue, 13 May 2025 08:25:53 +0200
Subject: [PATCH 22/33] simplify TimeSeries.__getitem__

---
 .../data/timeseries/_timeseries_v2.py | 73 +++++++++++--------
 1 file changed, 44 insertions(+), 29 deletions(-)

diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 1f0ba6820..5e24f6454 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -206,54 +206,69 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
         weights : torch.Tensor of shape (n_timepoints,), optional
             Only included if weights are not `None`.
         """
-        group_id = self._group_ids[index]
-
-        if self._group:
-            mask = self._groups[group_id]
+        time = self.time
+        feature_cols = self.feature_cols
+        _target = self._target
+        _known = self._known
+        _static = self._static
+        _group = self._group
+        _groups = self._groups
+        _group_ids = self._group_ids
+        weight = self.weight
+        data_future = self.data_future
+
+        group_id = _group_ids[index]
+
+        if _group:
+            mask = _groups[group_id]
             data = self.data.loc[mask]
         else:
             data = self.data
 
-        cutoff_time = data[self.time].max()
+        cutoff_time = data[time].max()
+
+        data_vals = data[time].values
+        data_tgt_vals = data[_target].values
+        data_feat_vals = data[feature_cols].values
 
         result = {
-            "t": data[self.time].values,
-            "y": torch.tensor(data[self._target].values),
-            "x": torch.tensor(data[self.feature_cols].values),
+            "t": data_vals,
+            "y": torch.tensor(data_tgt_vals),
+            "x": torch.tensor(data_feat_vals),
             "group": torch.tensor([hash(str(group_id))]),
-            "st": torch.tensor(data[self._static].iloc[0].values if self._static else []),
+            "st": torch.tensor(data[_static].iloc[0].values if _static else []),
             "cutoff_time": cutoff_time,
         }
 
-        if self.data_future is not None:
-            if self._group:
-                future_mask = self.data_future.groupby(self._group).groups[group_id]
+        if data_future is not None:
+            if _group:
+                future_mask = self.data_future.groupby(_group).groups[group_id]
                 future_data = self.data_future.loc[future_mask]
             else:
                 future_data = self.data_future
 
-            combined_times = np.concatenate(
-                [data[self.time].values, future_data[self.time].values]
-            )
+            data_fut_vals = future_data[time].values
+
+            combined_times = np.concatenate([data_vals, data_fut_vals])
             combined_times = np.unique(combined_times)
             combined_times.sort()
 
             num_timepoints = len(combined_times)
-            x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)
-            y_merged = np.full((num_timepoints, len(self._target)), np.nan)
+            x_merged = np.full((num_timepoints, len(feature_cols)), np.nan)
+            y_merged = np.full((num_timepoints, len(_target)), np.nan)
 
             current_time_indices = {t: i for i, t in enumerate(combined_times)}
-            for i, t in enumerate(data[self.time].values):
+            for i, t in enumerate(data_vals):
                 idx = current_time_indices[t]
-                x_merged[idx] = data[self.feature_cols].values[i]
-                y_merged[idx] = data[self._target].values[i]
+                x_merged[idx] = data_feat_vals[i]
+                y_merged[idx] = data_tgt_vals[i]
 
-            for i, t in enumerate(future_data[self.time].values):
+            for i, t in enumerate(data_fut_vals):
                 if t in current_time_indices:
                     idx = current_time_indices[t]
-                    for j, col in enumerate(self._known):
-                        if col in self.feature_cols:
-                            feature_idx = self.feature_cols.index(col)
+                    for j, col in enumerate(_known):
+                        if col in feature_cols:
+                            feature_idx = feature_cols.index(col)
                             x_merged[idx, feature_idx] = future_data[col].values[i]
 
             result.update(
                 {
                     "t": combined_times,
                     "x": torch.tensor(x_merged, dtype=torch.float32),
                     "y": torch.tensor(y_merged, dtype=torch.float32),
                 }
             )
 
-        if self.weight:
+        if weight:
             if self.data_future is not None and self.weight in self.data_future.columns:
                 weights_merged = np.full(num_timepoints, np.nan)
-                for i, t in enumerate(data[self.time].values):
+                for i, t in enumerate(data_vals):
                     idx = current_time_indices[t]
-                    weights_merged[idx] = data[self.weight].values[i]
+                    weights_merged[idx] = data[weight].values[i]
 
-                for i, t in enumerate(future_data[self.time].values):
+                for i, t in enumerate(data_fut_vals):
                     if t in current_time_indices and self.weight in future_data.columns:
                         idx = current_time_indices[t]
-                        weights_merged[idx] = future_data[self.weight].values[i]
+                        weights_merged[idx] = future_data[weight].values[i]
 
                 result["weights"] = torch.tensor(weights_merged, dtype=torch.float32)
             else:
                 result["weights"] = torch.tensor(
From c289255286540b96ddcf5667851f06edf7af0c7f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Tue, 13 May 2025 08:36:17 +0200
Subject: [PATCH 23/33] Update _timeseries_v2.py

---
 pytorch_forecasting/data/timeseries/_timeseries_v2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 5e24f6454..178b273bc 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -115,7 +115,7 @@ def __init__(
             UserWarning,
         )
 
-        super.__init__()
+        super().__init__()
 
         # handle defaults, coercion, and derived attributes
         self._target = _coerce_to_list(target)
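The one-character fix above is load-bearing: `super.__init__()` invokes the initializer of the builtin `super` type itself rather than the parent class, and fails at call time. A quick illustration (hypothetical classes):

class Base:
    def __init__(self):
        self.ready = True


class Fixed(Base):
    def __init__(self):
        super().__init__()  # correct: runs Base.__init__


Fixed()  # works; the instance gets .ready
# Writing `super.__init__()` inside a method instead raises:
# TypeError: descriptor '__init__' of 'super' object needs an argument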
From 9467f387287f3ba4a56ef1a1a4673c2215deb355 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Tue, 13 May 2025 08:44:38 +0200
Subject: [PATCH 24/33] Update data_module.py

---
 pytorch_forecasting/data/data_module.py | 65 ++++++++++++------------
 1 file changed, 32 insertions(+), 33 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 690fb6057..7b0d45312 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -171,39 +171,38 @@ def _prepare_metadata(self):
         dict
             dictionary containing the following keys:
 
-            * ``encoder_cat``: Number of categorical variables in the encoder.
-                Computed as ``len(self.categorical_indices)``, which counts the
-                categorical feature indices.
-            * ``encoder_cont``: Number of continuous variables in the encoder.
-                Computed as ``len(self.continuous_indices)``, which counts the
-                continuous feature indices.
-            * ``decoder_cat``: Number of categorical variables in the decoder that
-                are known in advance.
-                Computed by filtering ``self.time_series_metadata["cols"]["x"]``
-                where col_type == "C"(categorical) and col_known == "K" (known)
-            * ``decoder_cont``: Number of continuous variables in the decoder that
-                are known in advance.
-                Computed by filtering ``self.time_series_metadata["cols"]["x"]``
-                where col_type == "F"(continuous) and col_known == "K"(known)
-            * ``target``: Number of target variables.
-                Computed as ``len(self.time_series_metadata["cols"]["y"])``, which
-                gives the number of output target columns..
-            * ``static_categorical_features``: Number of static categorical features
-                Computed by filtering ``self.time_series_metadata["cols"]["st"]``
-                (static features) where col_type == "C" (categorical).
-            * ``static_continuous_features``: Number of static continuous features
-                Computed as difference of
-                ``len(self.time_series_metadata["cols"]["st"])`` (static features)
-                and static_categorical_features that gives static continuous feature
-            * ``max_encoder_length``: maximum encoder length
-                Taken directly from `self.max_encoder_length`.
-            * ``max_prediction_length``: maximum prediction length
-                Taken directly from `self.max_prediction_length`.
-            * ``min_encoder_length``: minimum encoder length
-                Taken directly from `self.min_encoder_length`.
-            * ``min_prediction_length``: minimum prediction length
-                Taken directly from `self.min_prediction_length`.
-
+        * ``encoder_cat``: Number of categorical variables in the encoder.
+            Computed as ``len(self.categorical_indices)``, which counts the
+            categorical feature indices.
+        * ``encoder_cont``: Number of continuous variables in the encoder.
+            Computed as ``len(self.continuous_indices)``, which counts the
+            continuous feature indices.
+        * ``decoder_cat``: Number of categorical variables in the decoder that
+            are known in advance.
+            Computed by filtering ``self.time_series_metadata["cols"]["x"]``
+            where col_type == "C" (categorical) and col_known == "K" (known)
+        * ``decoder_cont``: Number of continuous variables in the decoder that
+            are known in advance.
+            Computed by filtering ``self.time_series_metadata["cols"]["x"]``
+            where col_type == "F" (continuous) and col_known == "K" (known)
+        * ``target``: Number of target variables.
+            Computed as ``len(self.time_series_metadata["cols"]["y"])``, which
+            gives the number of output target columns.
+        * ``static_categorical_features``: Number of static categorical features
+            Computed by filtering ``self.time_series_metadata["cols"]["st"]``
+            (static features) where col_type == "C" (categorical).
+        * ``static_continuous_features``: Number of static continuous features
+            Computed as the difference of
+            ``len(self.time_series_metadata["cols"]["st"])`` (static features)
+            and ``static_categorical_features``.
+        * ``max_encoder_length``: maximum encoder length
+            Taken directly from `self.max_encoder_length`.
+        * ``max_prediction_length``: maximum prediction length
+            Taken directly from `self.max_prediction_length`.
+        * ``min_encoder_length``: minimum encoder length
+            Taken directly from `self.min_encoder_length`.
+        * ``min_prediction_length``: minimum prediction length
+            Taken directly from `self.min_prediction_length`.
         """
         encoder_cat_count = len(self.categorical_indices)
         encoder_cont_count = len(self.continuous_indices)

From c3b40ad0f3298e84b70b12a050614da3909799e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Tue, 13 May 2025 08:50:43 +0200
Subject: [PATCH 25/33] backwards compat of private/public attrs

---
 pytorch_forecasting/data/data_module.py               |  8 ++++++++
 pytorch_forecasting/data/timeseries/_timeseries_v2.py | 10 ++++++++++
 2 files changed, 18 insertions(+)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 7b0d45312..c8252014d 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -163,6 +163,14 @@ def __init__(
             else:
                 self.continuous_indices.append(idx)
 
+        # overwrite __init__ params for upwards compatibility with AS PRs
+        # todo: should we avoid this and ensure classes are dataclass-like?
+        self.min_prediction_length = self._min_prediction_length
+        self.min_encoder_length = self._min_encoder_length
+        self.categorical_encoders = self._categorical_encoders
+        self.scalers = self._scalers
+        self.target_normalizer = self._target_normalizer
+
     def _prepare_metadata(self):
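The idiom used in both hunks of this patch, namely storing constructor arguments verbatim, resolving defaults into underscore-prefixed attributes, and then aliasing the public names back to the resolved values, keeps code that reads the public attributes working while the raw arguments stay inspectable during construction. A condensed sketch of the pattern (hypothetical class):

from typing import Optional


class Example:
    def __init__(self, max_length: int = 30, min_length: Optional[int] = None):
        self.max_length = max_length
        self.min_length = min_length  # raw argument, may be None

        # resolved default lives on a private attribute
        self._min_length = min_length if min_length is not None else max_length

        # backwards compat: the public attribute ends up holding the resolved value
        self.min_length = self._min_length


assert Example().min_length == 30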
diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 178b273bc..d5ecbcabb 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -140,6 +140,16 @@ def __init__(
 
         self._prepare_metadata()
 
+        # overwrite __init__ params for upwards compatibility with AS PRs
+        # todo: should we avoid this and ensure classes are dataclass-like?
+        self.group = self._group
+        self.target = self._target
+        self.num = self._num
+        self.cat = self._cat
+        self.known = self._known
+        self.unknown = self._unknown
+        self.static = self._static
+
     def _prepare_metadata(self):
         """Prepare metadata for the dataset.

From 38c28dc031ecebddca3385bb0f1c58b4423a1b35 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Wed, 14 May 2025 18:51:05 +0530
Subject: [PATCH 26/33] add tests

---
 .../tft_version_two.py           |  38 +-
 tests/test_models/test_tft_v2.py | 367 ++++++++++++++++++
 2 files changed, 398 insertions(+), 7 deletions(-)
 create mode 100644 tests/test_models/test_tft_v2.py

diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py
index 2bfe407d7..1a1634356 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py
@@ -157,11 +157,11 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         if self.static_context_linear is not None:
             static_cat = x.get(
                 "static_categorical_features",
-                torch.zeros(batch_size, 0, device=self.device),
+                torch.zeros(batch_size, 1, 0, device=self.device),
             )
             static_cont = x.get(
                 "static_continuous_features",
-                torch.zeros(batch_size, 0, device=self.device),
+                torch.zeros(batch_size, 1, 0, device=self.device),
            )
 
             if static_cat.size(2) == 0 and static_cont.size(2) == 0:
@@ -180,17 +180,41 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
                 static_context = static_context.view(batch_size, self.hidden_size)
             else:
-                static_input = torch.cat([static_cont, static_cat], dim=1).to(
+                static_input = torch.cat([static_cont, static_cat], dim=2).to(
                     dtype=self.static_context_linear.weight.dtype
                 )
                 static_context = self.static_context_linear(static_input)
                 static_context = static_context.view(batch_size, self.hidden_size)
 
-        encoder_weights = self.encoder_var_selection(encoder_input)
-        encoder_input = encoder_input * encoder_weights
+        if self.encoder_var_selection is not None:
+            encoder_weights = self.encoder_var_selection(encoder_input)
+            encoder_input = encoder_input * encoder_weights
+        else:
+            if self.encoder_input_dim == 0:
+                encoder_input = torch.zeros(
+                    batch_size,
+                    self.max_encoder_length,
+                    1,
+                    device=self.device,
+                    dtype=encoder_input.dtype,
+                )
+            else:
+                encoder_input = encoder_input
 
-        decoder_weights = self.decoder_var_selection(decoder_input)
-        decoder_input = decoder_input * decoder_weights
+        if self.decoder_var_selection is not None:
+            decoder_weights = self.decoder_var_selection(decoder_input)
+            decoder_input = decoder_input * decoder_weights
+        else:
+            if self.decoder_input_dim == 0:
+                decoder_input = torch.zeros(
+                    batch_size,
+                    self.max_prediction_length,
+                    1,
+                    device=self.device,
+                    dtype=decoder_input.dtype,
+                )
+            else:
+                decoder_input = decoder_input
 
         if static_context is not None:
             encoder_static_context = static_context.unsqueeze(1).expand(
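The `(batch_size, 1, 0)` placeholders and the switch to `dim=2` lean on PyTorch's support for zero-width tensors: concatenation along the feature axis is legal even when one operand has no features, so absent static inputs need no special casing. A small check (illustrative):

import torch

batch_size, hidden_size = 2, 8

static_cat = torch.zeros(batch_size, 1, 0)   # no categorical statics
static_cont = torch.randn(batch_size, 1, 3)  # three continuous statics

static_input = torch.cat([static_cont, static_cat], dim=2)
assert static_input.shape == (2, 1, 3)  # the zero-width operand adds nothing

linear = torch.nn.Linear(3, hidden_size)
context = linear(static_input).view(batch_size, hidden_size)
assert context.shape == (2, 8)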
diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py
new file mode 100644
index 000000000..e69d3d06d
--- /dev/null
+++ b/tests/test_models/test_tft_v2.py
@@ -0,0 +1,367 @@
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+import torch.nn as nn
+
+from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule
+from pytorch_forecasting.data.timeseries import TimeSeries
+from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT
+
+BATCH_SIZE_TEST = 2
+MAX_ENCODER_LENGTH_TEST = 10
+MAX_PREDICTION_LENGTH_TEST = 5
+HIDDEN_SIZE_TEST = 8
+OUTPUT_SIZE_TEST = 1
+ATTENTION_HEAD_SIZE_TEST = 2
+NUM_LAYERS_TEST = 1
+DROPOUT_TEST = 0.1
+
+
+def get_default_test_metadata(
+    enc_cont=2,
+    enc_cat=1,
+    dec_cont=1,
+    dec_cat=1,
+    static_cat=1,
+    static_cont=1,
+    output_size=OUTPUT_SIZE_TEST,
+):
+    return {
+        "max_encoder_length": MAX_ENCODER_LENGTH_TEST,
+        "max_prediction_length": MAX_PREDICTION_LENGTH_TEST,
+        "encoder_cont": enc_cont,
+        "encoder_cat": enc_cat,
+        "decoder_cont": dec_cont,
+        "decoder_cat": dec_cat,
+        "static_categorical_features": static_cat,
+        "static_continuous_features": static_cont,
+        "target": output_size,
+    }
+
+
+def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"):
+    def _get_dim_val(key):
+        return metadata.get(key, 0)
+
+    x = {
+        "encoder_cont": torch.randn(
+            batch_size,
+            metadata["max_encoder_length"],
+            _get_dim_val("encoder_cont"),
+            device=device,
+        ),
+        "encoder_cat": torch.randn(
+            batch_size,
+            metadata["max_encoder_length"],
+            _get_dim_val("encoder_cat"),
+            device=device,
+        ),
+        "decoder_cont": torch.randn(
+            batch_size,
+            metadata["max_prediction_length"],
+            _get_dim_val("decoder_cont"),
+            device=device,
+        ),
+        "decoder_cat": torch.randn(
+            batch_size,
+            metadata["max_prediction_length"],
+            _get_dim_val("decoder_cat"),
+            device=device,
+        ),
+        "static_categorical_features": torch.randn(
+            batch_size, 1, _get_dim_val("static_categorical_features"), device=device
+        ),
+        "static_continuous_features": torch.randn(
+            batch_size, 1, _get_dim_val("static_continuous_features"), device=device
+        ),
+        "encoder_lengths": torch.full(
+            (batch_size,),
+            metadata["max_encoder_length"],
+            dtype=torch.long,
+            device=device,
+        ),
+        "decoder_lengths": torch.full(
+            (batch_size,),
+            metadata["max_prediction_length"],
+            dtype=torch.long,
+            device=device,
+        ),
+        "groups": torch.arange(batch_size, device=device).unsqueeze(1),
+        "encoder_time_idx": torch.stack(
+            [torch.arange(metadata["max_encoder_length"], device=device)] * batch_size
+        ),
+        "decoder_time_idx": torch.stack(
+            [
+                torch.arange(
+                    metadata["max_encoder_length"],
+                    metadata["max_encoder_length"] + metadata["max_prediction_length"],
+                    device=device,
+                )
+            ]
+            * batch_size
+        ),
+        "target_scale": torch.ones((batch_size, 1), device=device),
+    }
+    return x
+
+
+dummy_loss_for_test = nn.MSELoss()
+
+
+@pytest.fixture(scope="module")
+def tft_model_params_fixture_func():
+    return {
+        "loss": dummy_loss_for_test,
+        "hidden_size": HIDDEN_SIZE_TEST,
+        "num_layers": NUM_LAYERS_TEST,
+        "attention_head_size": ATTENTION_HEAD_SIZE_TEST,
+        "dropout": DROPOUT_TEST,
+        "output_size": OUTPUT_SIZE_TEST,
+    }
+
+
+class TestTFTInitialization:
+    def test_basic_initialization(self, tft_model_params_fixture_func):
+        metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST)
+        model = TFT(**tft_model_params_fixture_func, metadata=metadata)
+        assert model.hidden_size == HIDDEN_SIZE_TEST
+        assert model.num_layers == NUM_LAYERS_TEST
+        assert hasattr(model, "metadata") and model.metadata == metadata
+        assert (
+            model.encoder_input_dim
+            == metadata["encoder_cont"] + metadata["encoder_cat"]
+        )
+        assert (
+            model.static_input_dim
+            == metadata["static_categorical_features"]
+            + metadata["static_continuous_features"]
+        )
+        assert isinstance(model.lstm_encoder, nn.LSTM)
+        assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim)
+        assert isinstance(model.self_attention, nn.MultiheadAttention)
+        if hasattr(model, "hparams") and model.hparams:
+            assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST
+        assert model.output_size == OUTPUT_SIZE_TEST
+
+    def test_initialization_no_time_varying_features(
+        self, tft_model_params_fixture_func
+    ):
+        metadata = get_default_test_metadata(
+            enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST
+        )
+        model = TFT(**tft_model_params_fixture_func, metadata=metadata)
+        assert model.encoder_input_dim == 0
+        assert model.encoder_var_selection is None
+        assert model.lstm_encoder.input_size == 1
+        assert model.decoder_input_dim == 0
+        assert model.decoder_var_selection is None
+        assert model.lstm_decoder.input_size == 1
+
+    def test_initialization_no_static_features(self, tft_model_params_fixture_func):
+        metadata = get_default_test_metadata(
+            static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST
+        )
+        model = TFT(**tft_model_params_fixture_func, metadata=metadata)
+        assert model.static_input_dim == 0
+        assert model.static_context_linear is None
+
+
+class TestTFTForwardPass:
+    @pytest.mark.parametrize(
+        "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k",
+        [
+            (2, 1, 1, 1, 1, 1),
+            (2, 0, 1, 0, 0, 0),
+            (0, 0, 0, 0, 1, 1),
+            (0, 0, 0, 0, 0, 0),
+            (1, 0, 1, 0, 1, 0),
+            (1, 0, 1, 0, 0, 1),
+        ],
+    )
+    def test_forward_pass_configs(
+        self, tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k
+    ):
+        current_tft_actual_output_size = tft_model_params_fixture_func["output_size"]
+        metadata = get_default_test_metadata(
+            enc_cont=enc_c,
+            enc_cat=enc_k,
+            dec_cont=dec_c,
+            dec_cat=dec_k,
+            static_cat=stat_c,
+            static_cont=stat_k,
+            output_size=current_tft_actual_output_size,
+        )
+        model_params = tft_model_params_fixture_func.copy()
+        model_params["output_size"] = current_tft_actual_output_size
+        model = TFT(**model_params, metadata=metadata)
+        model.eval()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model.to(device)
+        x = create_tft_input_batch_for_test(
+            metadata, batch_size=BATCH_SIZE_TEST, device=device
+        )
+        output_dict = model(x)
+        predictions = output_dict["prediction"]
+        assert predictions.shape == (
+            BATCH_SIZE_TEST,
+            MAX_PREDICTION_LENGTH_TEST,
+            current_tft_actual_output_size,
+        )
+        assert not torch.isnan(predictions).any(), "NaNs in prediction"
+        assert not torch.isinf(predictions).any(), "Infs in prediction"
+
+
+@pytest.fixture
+def sample_pandas_data_for_test():
+    """Create sample data ensuring all feature columns are numeric (float32)."""
+    series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5
+    num_groups = 6
+    data = []
+
+    for i in range(num_groups):
+        static_cont_val = np.float32(i * 10.0)
+        static_cat_code = np.float32(i % 2)
+
+        df_group = pd.DataFrame(
+            {
+                "time_idx": np.arange(series_len, dtype=np.int64),
+                "group_id_str": np.repeat(f"g{i}", series_len),
+                "target": np.random.rand(series_len).astype(np.float32) + i,
+                "enc_cont1": np.random.rand(series_len).astype(np.float32),
+                "enc_cat1_codes": np.random.randint(0, 3, series_len).astype(
+                    np.float32
+                ),
+                "dec_known_cont": np.sin(np.arange(series_len) / 5.0).astype(
+                    np.float32
+                ),
+                "dec_known_cat_codes": np.random.randint(0, 2, series_len).astype(
+                    np.float32
+                ),
+                "static_cat_feat_codes": np.full(
+                    series_len, static_cat_code, dtype=np.float32
+                ),
+                "static_cont_feat": np.full(
+                    series_len, static_cont_val, dtype=np.float32
+                ),
+            }
+        )
+        data.append(df_group)
+
+    df = pd.concat(data, ignore_index=True)
+
+    df["group_id"] = df["group_id_str"].astype("category")
+    df.drop(columns=["group_id_str"], inplace=True)
+
+    return df
+
+
+@pytest.fixture
+def timeseries_obj_for_test(sample_pandas_data_for_test):
+    df = sample_pandas_data_for_test
+
+    return TimeSeries(
+        data=df,
+        time="time_idx",
+        target="target",
+        group=["group_id"],
+        num=[
+            "enc_cont1",
+            "enc_cat1_codes",
+            "dec_known_cont",
+            "dec_known_cat_codes",
+            "static_cat_feat_codes",
+            "static_cont_feat",
+        ],
+        cat=[],
+        known=["dec_known_cont", "dec_known_cat_codes", "time_idx"],
+        static=["static_cat_feat_codes", "static_cont_feat"],
+    )
+
+
+@pytest.fixture
+def data_module_for_test(timeseries_obj_for_test):
+    dm = EncoderDecoderTimeSeriesDataModule(
+        time_series_dataset=timeseries_obj_for_test,
+        batch_size=BATCH_SIZE_TEST,
+        max_encoder_length=MAX_ENCODER_LENGTH_TEST,
+        max_prediction_length=MAX_PREDICTION_LENGTH_TEST,
+        train_val_test_split=(0.5, 0.25, 0.25),
+        num_workers=0,  # Added for consistency
+    )
+    dm.setup("fit")
+    dm.setup("test")
+    return dm
batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] + ) + assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] + # assert ( + # batch_x["static_categorical_features"].shape[2] + # == model_metadata_from_dm["static_categorical_features"] + # ) + # assert ( + # batch_x["static_continuous_features"].shape[2] + # == model_metadata_from_dm["static_continuous_features"] + # ) + + output_dict = model(batch_x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) + assert not torch.isnan(predictions).any() + assert batch_y.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) From 9d80eb822e47c92e3b542cd70fe98103e00bd829 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:10:57 +0530 Subject: [PATCH 27/33] add tests --- tests/test_models/test_tft_v2.py | 311 +++++++++++++++---------------- 1 file changed, 152 insertions(+), 159 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index e69d3d06d..0455ad818 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -121,95 +121,92 @@ def tft_model_params_fixture_func(): } -class TestTFTInitialization: - def test_basic_initialization(self, tft_model_params_fixture_func): - metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.hidden_size == HIDDEN_SIZE_TEST - assert model.num_layers == NUM_LAYERS_TEST - assert hasattr(model, "metadata") and model.metadata == metadata - assert ( - model.encoder_input_dim - == metadata["encoder_cont"] + metadata["encoder_cat"] - ) - assert ( - model.static_input_dim - == metadata["static_categorical_features"] - + metadata["static_continuous_features"] - ) - assert isinstance(model.lstm_encoder, nn.LSTM) - assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim) - assert isinstance(model.self_attention, nn.MultiheadAttention) - if hasattr(model, "hparams") and model.hparams: - assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST - assert model.output_size == OUTPUT_SIZE_TEST - - def test_initialization_no_time_varying_features( - self, tft_model_params_fixture_func - ): - metadata = get_default_test_metadata( - enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST - ) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.encoder_input_dim == 0 - assert model.encoder_var_selection is None - assert model.lstm_encoder.input_size == 1 - assert model.decoder_input_dim == 0 - assert model.decoder_var_selection is None - assert model.lstm_decoder.input_size == 1 - - def test_initialization_no_static_features(self, tft_model_params_fixture_func): - metadata = get_default_test_metadata( - static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST - ) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.static_input_dim == 0 - assert model.static_context_linear is None - - -class TestTFTForwardPass: - @pytest.mark.parametrize( - "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", - [ - (2, 1, 1, 1, 1, 1), - (2, 0, 1, 0, 0, 0), - (0, 0, 0, 0, 1, 1), - (0, 0, 0, 0, 0, 0), - (1, 0, 1, 0, 1, 0), - (1, 0, 1, 0, 0, 1), - ], +# Converted from TestTFTInitialization class +def test_basic_initialization(tft_model_params_fixture_func): + metadata = 
From 9d80eb822e47c92e3b542cd70fe98103e00bd829 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Wed, 14 May 2025 19:10:57 +0530
Subject: [PATCH 27/33] add tests

---
 tests/test_models/test_tft_v2.py | 311 +++++++++++++++----------------
 1 file changed, 152 insertions(+), 159 deletions(-)

diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py
index e69d3d06d..0455ad818 100644
--- a/tests/test_models/test_tft_v2.py
+++ b/tests/test_models/test_tft_v2.py
@@ -121,95 +121,92 @@ def tft_model_params_fixture_func():
     }
 
 
-class TestTFTInitialization:
-    def test_basic_initialization(self, tft_model_params_fixture_func):
-        metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST)
-        model = TFT(**tft_model_params_fixture_func, metadata=metadata)
-        assert model.hidden_size == HIDDEN_SIZE_TEST
-        assert model.num_layers == NUM_LAYERS_TEST
-        assert hasattr(model, "metadata") and model.metadata == metadata
-        assert (
-            model.encoder_input_dim
-            == metadata["encoder_cont"] + metadata["encoder_cat"]
-        )
-        assert (
-            model.static_input_dim
-            == metadata["static_categorical_features"]
-            + metadata["static_continuous_features"]
-        )
-        assert isinstance(model.lstm_encoder, nn.LSTM)
-        assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim)
-        assert isinstance(model.self_attention, nn.MultiheadAttention)
-        if hasattr(model, "hparams") and model.hparams:
-            assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST
-        assert model.output_size == OUTPUT_SIZE_TEST
-
-    def test_initialization_no_time_varying_features(
-        self, tft_model_params_fixture_func
-    ):
-        metadata = get_default_test_metadata(
-            enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST
-        )
-        model = TFT(**tft_model_params_fixture_func, metadata=metadata)
-        assert model.encoder_input_dim == 0
-        assert model.encoder_var_selection is None
-        assert model.lstm_encoder.input_size == 1
-        assert model.decoder_input_dim == 0
-        assert model.decoder_var_selection is None
-        assert model.lstm_decoder.input_size == 1
-
-    def test_initialization_no_static_features(self, tft_model_params_fixture_func):
-        metadata = get_default_test_metadata(
-            static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST
-        )
-        model = TFT(**tft_model_params_fixture_func, metadata=metadata)
-        assert model.static_input_dim == 0
-        assert model.static_context_linear is None
-
-
-class TestTFTForwardPass:
-    @pytest.mark.parametrize(
-        "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k",
-        [
-            (2, 1, 1, 1, 1, 1),
-            (2, 0, 1, 0, 0, 0),
-            (0, 0, 0, 0, 1, 1),
-            (0, 0, 0, 0, 0, 0),
-            (1, 0, 1, 0, 1, 0),
-            (1, 0, 1, 0, 0, 1),
-        ],
-    )
-    def test_forward_pass_configs(
-        self, tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k
-    ):
-        current_tft_actual_output_size = tft_model_params_fixture_func["output_size"]
-        metadata = get_default_test_metadata(
-            enc_cont=enc_c,
-            enc_cat=enc_k,
-            dec_cont=dec_c,
-            dec_cat=dec_k,
-            static_cat=stat_c,
-            static_cont=stat_k,
-            output_size=current_tft_actual_output_size,
-        )
-        model_params = tft_model_params_fixture_func.copy()
-        model_params["output_size"] = current_tft_actual_output_size
-        model = TFT(**model_params, metadata=metadata)
-        model.eval()
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device)
-        x = create_tft_input_batch_for_test(
-            metadata, batch_size=BATCH_SIZE_TEST, device=device
-        )
-        output_dict = model(x)
-        predictions = output_dict["prediction"]
-        assert predictions.shape == (
-            BATCH_SIZE_TEST,
-            MAX_PREDICTION_LENGTH_TEST,
-            current_tft_actual_output_size,
-        )
-        assert not torch.isnan(predictions).any(), "NaNs in prediction"
-        assert not torch.isinf(predictions).any(), "Infs in prediction"
+# Converted from TestTFTInitialization class
+def test_basic_initialization(tft_model_params_fixture_func):
+    metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST)
+    model = TFT(**tft_model_params_fixture_func, metadata=metadata)
+    assert model.hidden_size == HIDDEN_SIZE_TEST
+    assert model.num_layers == NUM_LAYERS_TEST
+    assert hasattr(model, "metadata") and model.metadata == metadata
+    assert model.encoder_input_dim == metadata["encoder_cont"] + metadata["encoder_cat"]
+    assert (
+        model.static_input_dim
+        == metadata["static_categorical_features"]
+        + metadata["static_continuous_features"]
    )
+    assert isinstance(model.lstm_encoder, nn.LSTM)
+    assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim)
+    assert isinstance(model.self_attention, nn.MultiheadAttention)
+    if hasattr(model, "hparams") and model.hparams:
+        assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST
+    assert model.output_size == OUTPUT_SIZE_TEST
+
+
+def test_initialization_no_time_varying_features(tft_model_params_fixture_func):
+    metadata = get_default_test_metadata(
+        enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST
+    )
+    model = TFT(**tft_model_params_fixture_func, metadata=metadata)
+    assert model.encoder_input_dim == 0
+    assert model.encoder_var_selection is None
+    assert model.lstm_encoder.input_size == 1
+    assert model.decoder_input_dim == 0
+    assert model.decoder_var_selection is None
+    assert model.lstm_decoder.input_size == 1
+
+
+def test_initialization_no_static_features(tft_model_params_fixture_func):
+    metadata = get_default_test_metadata(
+        static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST
+    )
+    model = TFT(**tft_model_params_fixture_func, metadata=metadata)
+    assert model.static_input_dim == 0
+    assert model.static_context_linear is None
+
+
+# Converted from TestTFTForwardPass class
+@pytest.mark.parametrize(
+    "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k",
+    [
+        (2, 1, 1, 1, 1, 1),
+        (2, 0, 1, 0, 0, 0),
+        (0, 0, 0, 0, 1, 1),
+        (0, 0, 0, 0, 0, 0),
+        (1, 0, 1, 0, 1, 0),
+        (1, 0, 1, 0, 0, 1),
+    ],
+)
+def test_forward_pass_configs(
+    tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k
+):
+    current_tft_actual_output_size = tft_model_params_fixture_func["output_size"]
+    metadata = get_default_test_metadata(
+        enc_cont=enc_c,
+        enc_cat=enc_k,
+        dec_cont=dec_c,
+        dec_cat=dec_k,
+        static_cat=stat_c,
+        static_cont=stat_k,
+        output_size=current_tft_actual_output_size,
+    )
+    model_params = tft_model_params_fixture_func.copy()
+    model_params["output_size"] = current_tft_actual_output_size
+    model = TFT(**model_params, metadata=metadata)
+    model.eval()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    x = create_tft_input_batch_for_test(
+        metadata, batch_size=BATCH_SIZE_TEST, device=device
+    )
+    output_dict = model(x)
+    predictions = output_dict["prediction"]
+    assert predictions.shape == (
+        BATCH_SIZE_TEST,
+        MAX_PREDICTION_LENGTH_TEST,
+        current_tft_actual_output_size,
+    )
+    assert not torch.isnan(predictions).any(), "NaNs in prediction"
+    assert not torch.isinf(predictions).any(), "Infs in prediction"
 
 
 @pytest.fixture
 def sample_pandas_data_for_test():
@@ -294,74 +291,70 @@ def data_module_for_test(timeseries_obj_for_test):
     return dm
 
 
-class TestTFTWithDataModule:
-    def test_model_with_datamodule_integration(
-        self, tft_model_params_fixture_func, data_module_for_test
-    ):
-        dm = data_module_for_test
-        model_metadata_from_dm = dm.metadata
-
-        assert (
-            model_metadata_from_dm["encoder_cont"] == 6
-        ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}"
-        assert (
-            model_metadata_from_dm["encoder_cat"] == 0
-        ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}"
-        assert (
-            model_metadata_from_dm["decoder_cont"] == 2
-        ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}"
-        assert (
-            model_metadata_from_dm["decoder_cat"] == 0
-        ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}"
-        assert (
-            model_metadata_from_dm["static_categorical_features"] == 0
-        ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}"
-        assert (
-            model_metadata_from_dm["static_continuous_features"] == 2
-        ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}"
-        assert model_metadata_from_dm["target"] == 1
-
-        tft_init_args = tft_model_params_fixture_func.copy()
-        tft_init_args["output_size"] = model_metadata_from_dm["target"]
-        model = TFT(**tft_init_args, metadata=model_metadata_from_dm)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device)
-        model.eval()
-
-        train_loader = dm.train_dataloader()
-        batch_x, batch_y = next(iter(train_loader))
-
-        actual_batch_size = batch_x["encoder_cont"].shape[0]
-        batch_x = {k: v.to(device) for k, v in batch_x.items()}
-        batch_y = batch_y.to(device)
-
-        assert (
-            batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"]
-        )
-        assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"]
-        assert (
-            batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"]
-        )
-        assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"]
-        # assert ( ... static feature shape checks commented out, as above ... )
-
-        output_dict = model(batch_x)
-        predictions = output_dict["prediction"]
-        assert predictions.shape == (
-            actual_batch_size,
-            MAX_PREDICTION_LENGTH_TEST,
-            model_metadata_from_dm["target"],
-        )
-        assert not torch.isnan(predictions).any()
-        assert batch_y.shape == (
-            actual_batch_size,
-            MAX_PREDICTION_LENGTH_TEST,
-            model_metadata_from_dm["target"],
-        )
+# Converted from TestTFTWithDataModule class
+def test_model_with_datamodule_integration(
+    tft_model_params_fixture_func, data_module_for_test
+):
+    dm = data_module_for_test
+    model_metadata_from_dm = dm.metadata
+
+    assert (
+        model_metadata_from_dm["encoder_cont"] == 6
+    ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}"
f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}" + assert ( + model_metadata_from_dm["encoder_cat"] == 0 + ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}" + assert ( + model_metadata_from_dm["decoder_cont"] == 2 + ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}" + assert ( + model_metadata_from_dm["decoder_cat"] == 0 + ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}" + assert ( + model_metadata_from_dm["static_categorical_features"] == 0 + ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}" + assert ( + model_metadata_from_dm["static_continuous_features"] == 2 + ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}" + assert model_metadata_from_dm["target"] == 1 + + tft_init_args = tft_model_params_fixture_func.copy() + tft_init_args["output_size"] = model_metadata_from_dm["target"] + model = TFT(**tft_init_args, metadata=model_metadata_from_dm) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + + train_loader = dm.train_dataloader() + batch_x, batch_y = next(iter(train_loader)) + + actual_batch_size = batch_x["encoder_cont"].shape[0] + batch_x = {k: v.to(device) for k, v in batch_x.items()} + batch_y = batch_y.to(device) + + assert batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"] + assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] + assert batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] + assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] + # assert ( + # batch_x["static_categorical_features"].shape[2] + # == model_metadata_from_dm["static_categorical_features"] + # ) + # assert ( + # batch_x["static_continuous_features"].shape[2] + # == model_metadata_from_dm["static_continuous_features"] + # ) + + output_dict = model(batch_x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) + assert not torch.isnan(predictions).any() + assert batch_y.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) From a8ccfe36d383191ba6bd23902543aed40dbe0d39 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:17:36 +0530 Subject: [PATCH 28/33] add tests --- tests/test_models/test_tft_v2.py | 37 +++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index 0455ad818..ae74d59fc 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -123,6 +123,14 @@ def tft_model_params_fixture_func(): # Converted from TestTFTInitialization class def test_basic_initialization(tft_model_params_fixture_func): + """Test basic initialization of the TFT model with default metadata. + + Verifies: + - Model attributes match the provided metadata (e.g., hidden_size, num_layers). + - Proper construction of key model components (LSTM, attention, etc.). + - Correct dimensionality of input layers based on metadata. + - Model retains metadata and hyperparameters as expected. 
+    """
     metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST)
     model = TFT(**tft_model_params_fixture_func, metadata=metadata)
     assert model.hidden_size == HIDDEN_SIZE_TEST
@@ -143,6 +151,13 @@ def test_basic_initialization(tft_model_params_fixture_func):
 
 
 def test_initialization_no_time_varying_features(tft_model_params_fixture_func):
+    """Test TFT initialization with no time-varying (encoder/decoder) features.
+
+    Verifies:
+    - Model handles zero encoder/decoder input dimensions correctly.
+    - Skips creation of encoder/decoder variable selection networks.
+    - Defaults to input size 1 for LSTMs when no time-varying features exist.
+    """
     metadata = get_default_test_metadata(
         enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST
     )
@@ -156,6 +171,12 @@ def test_initialization_no_time_varying_features(tft_model_params_fixture_func):
 
 
 def test_initialization_no_static_features(tft_model_params_fixture_func):
+    """Test TFT initialization with no static features.
+
+    Verifies:
+    - Model static input dimension is 0.
+    - Static context linear layer is not created.
+    """
     metadata = get_default_test_metadata(
         static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST
     )
@@ -179,6 +200,13 @@ def test_initialization_no_static_features(tft_model_params_fixture_func):
 def test_forward_pass_configs(
     tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k
 ):
+    """Test TFT forward pass across multiple feature configurations.
+
+    Verifies:
+    - Forward pass completes without errors for varying combinations of input types.
+    - Output prediction tensor has the expected shape.
+    - Output contains no NaNs or infinities.
+    """
     current_tft_actual_output_size = tft_model_params_fixture_func["output_size"]
     metadata = get_default_test_metadata(
         enc_cont=enc_c,
@@ -211,7 +239,6 @@ def test_forward_pass_configs(
 
 @pytest.fixture
 def sample_pandas_data_for_test():
-    """Create sample data ensuring all feature columns are numeric (float32)."""
     series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5
     num_groups = 6
     data = []
@@ -295,6 +322,14 @@ def data_module_for_test(timeseries_obj_for_test):
 def test_model_with_datamodule_integration(
     tft_model_params_fixture_func, data_module_for_test
 ):
+    """Integration test to ensure TFT works correctly with the data module.
+
+    Verifies:
+    - Metadata inferred from data module matches expected input dimensions.
+    - Model processes real dataloader batches correctly.
+    - Output and target tensors from model and data module align in shape.
+    - No NaNs in predictions.
+    """
     dm = data_module_for_test
     model_metadata_from_dm = dm.metadata
 

From f900ba5e4d4912573e7dc79c398386e683d5e807 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Wed, 14 May 2025 19:24:21 +0530
Subject: [PATCH 29/33] add more docstrings

---
 tests/test_models/test_tft_v2.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py
index ae74d59fc..d79eac874 100644
--- a/tests/test_models/test_tft_v2.py
+++ b/tests/test_models/test_tft_v2.py
@@ -27,6 +27,7 @@ def get_default_test_metadata(
     static_cont=1,
     output_size=OUTPUT_SIZE_TEST,
 ):
+    """Return a dict representing default metadata for TFT model initialization."""
     return {
         "max_encoder_length": MAX_ENCODER_LENGTH_TEST,
         "max_prediction_length": MAX_PREDICTION_LENGTH_TEST,
@@ -41,6 +42,8 @@ def get_default_test_metadata(
 
 
 def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"):
+    """Create a synthetic input batch dictionary for testing TFT forward passes."""
+
     def _get_dim_val(key):
         return metadata.get(key, 0)
 
@@ -111,6 +114,7 @@ def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"):
 
 @pytest.fixture(scope="module")
 def tft_model_params_fixture_func():
+    """Create a default set of model parameters for TFT."""
     return {
         "loss": dummy_loss_for_test,
         "hidden_size": HIDDEN_SIZE_TEST,
@@ -121,7 +125,6 @@ def tft_model_params_fixture_func():
     }
 
 
-# Converted from TestTFTInitialization class
 def test_basic_initialization(tft_model_params_fixture_func):
     """Test basic initialization of the TFT model with default metadata.
 
@@ -239,6 +242,7 @@ def test_forward_pass_configs(
 
 @pytest.fixture
 def sample_pandas_data_for_test():
+    """Create synthetic multivariate time series data as a pandas DataFrame."""
     series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5
     num_groups = 6
     data = []
@@ -282,6 +286,7 @@ def sample_pandas_data_for_test():
 
 @pytest.fixture
 def timeseries_obj_for_test(sample_pandas_data_for_test):
+    """Convert sample DataFrame into a TimeSeries object."""
     df = sample_pandas_data_for_test
 
     return TimeSeries(
@@ -305,6 +310,7 @@ def timeseries_obj_for_test(sample_pandas_data_for_test):
 
 @pytest.fixture
 def data_module_for_test(timeseries_obj_for_test):
+    """Initialize and set up an EncoderDecoderTimeSeriesDataModule."""
     dm = EncoderDecoderTimeSeriesDataModule(
         time_series_dataset=timeseries_obj_for_test,
         batch_size=BATCH_SIZE_TEST,
@@ -318,7 +324,6 @@ def data_module_for_test(timeseries_obj_for_test):
     return dm
 
 
-# Converted from TestTFTWithDataModule class
 def test_model_with_datamodule_integration(
     tft_model_params_fixture_func, data_module_for_test
 ):

From ed1b79936df9c4cb18c29393f964228997001b98 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Wed, 14 May 2025 19:26:40 +0530
Subject: [PATCH 30/33] add note about the commented out tests

---
 tests/test_models/test_tft_v2.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py
index d79eac874..57a50e75e 100644
--- a/tests/test_models/test_tft_v2.py
+++ b/tests/test_models/test_tft_v2.py
@@ -188,7 +188,6 @@ def test_initialization_no_static_features(tft_model_params_fixture_func):
     assert model.static_context_linear is None
 
 
-# Converted from TestTFTForwardPass class
 @pytest.mark.parametrize(
     "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k",
     [
@@ -332,6 +331,8 @@ def test_model_with_datamodule_integration(
     - Model processes real dataloader batches correctly.
     - Output and target tensors from model and data module align in shape.
     - No NaNs in predictions.
+
+    Note: The commented-out assertions exercise a bug in data_module.
     """
     dm = data_module_for_test
     model_metadata_from_dm = dm.metadata

From c0ceb8a16703573144e3d0bd3aa6ab978157a341 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Sat, 17 May 2025 02:08:06 +0530
Subject: [PATCH 31/33] add the commented out tests

---
 tests/test_models/test_tft_v2.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py
index 57a50e75e..f541082ce 100644
--- a/tests/test_models/test_tft_v2.py
+++ b/tests/test_models/test_tft_v2.py
@@ -316,7 +316,6 @@ def data_module_for_test(timeseries_obj_for_test):
         max_encoder_length=MAX_ENCODER_LENGTH_TEST,
         max_prediction_length=MAX_PREDICTION_LENGTH_TEST,
         train_val_test_split=(0.5, 0.25, 0.25),
-        num_workers=0,  # Added for consistency
     )
     dm.setup("fit")
     dm.setup("test")
@@ -376,14 +376,14 @@ def test_model_with_datamodule_integration(
     assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"]
     assert batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"]
     assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"]
-    # assert (
-    #     batch_x["static_categorical_features"].shape[2]
-    #     == model_metadata_from_dm["static_categorical_features"]
-    # )
-    # assert (
-    #     batch_x["static_continuous_features"].shape[2]
-    #     == model_metadata_from_dm["static_continuous_features"]
-    # )
+    assert (
+        batch_x["static_categorical_features"].shape[2]
+        == model_metadata_from_dm["static_categorical_features"]
+    )
+    assert (
+        batch_x["static_continuous_features"].shape[2]
+        == model_metadata_from_dm["static_continuous_features"]
+    )
 
     output_dict = model(batch_x)
     predictions = output_dict["prediction"]

From 3828c260d4b32ee7fcd9fc300776126c70f6a3b6 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Sat, 17 May 2025 02:09:16 +0530
Subject: [PATCH 32/33] remove note

---
 tests/test_models/test_tft_v2.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py
index f541082ce..791ea10ef 100644
--- a/tests/test_models/test_tft_v2.py
+++ b/tests/test_models/test_tft_v2.py
@@ -332,8 +332,6 @@ def test_model_with_datamodule_integration(
     - Model processes real dataloader batches correctly.
    - Output and target tensors from model and data module align in shape.
     - No NaNs in predictions.
-
-    Note: The commented-out assertions exercise a bug in data_module.
     """
     dm = data_module_for_test
     model_metadata_from_dm = dm.metadata

From 30b541b2910c461e3e488e19137c1242c0b0627b Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Wed, 21 May 2025 00:52:29 +0530
Subject: [PATCH 33/33] make the modules private

---
 .../{base_model_refactor.py => _base_model_v2.py}     | 13 +++++++++++++
 .../{tft_version_two.py => _tft_v2.py}                |  2 +-
 tests/test_models/{test_tft_v2.py => _test_tft_v2.py} |  2 +-
 3 files changed, 15 insertions(+), 2 deletions(-)
 rename pytorch_forecasting/models/base/{base_model_refactor.py => _base_model_v2.py} (93%)
 rename pytorch_forecasting/models/temporal_fusion_transformer/{tft_version_two.py => _tft_v2.py} (99%)
 rename tests/test_models/{test_tft_v2.py => _test_tft_v2.py} (99%)

diff --git a/pytorch_forecasting/models/base/base_model_refactor.py b/pytorch_forecasting/models/base/_base_model_v2.py
similarity index 93%
rename from pytorch_forecasting/models/base/base_model_refactor.py
rename to pytorch_forecasting/models/base/_base_model_v2.py
index ccd2c2600..ddefc29fb 100644
--- a/pytorch_forecasting/models/base/base_model_refactor.py
+++ b/pytorch_forecasting/models/base/_base_model_v2.py
@@ -6,6 +6,7 @@
 
 from typing import Dict, List, Optional, Tuple, Union
+from warnings import warn
 
 from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -53,6 +54,18 @@ def __init__(
         self.lr_scheduler_params = (
             lr_scheduler_params if lr_scheduler_params is not None else {}
         )
+        self.model_name = self.__class__.__name__
+        warn(
+            f"The Model '{self.model_name}' is part of an experimental rework "
+            "of the pytorch-forecasting model layer, scheduled for release with v2.0.0."
+            " The API is not stable and may change without prior warning. "
+            "This class is intended for beta testing and as a basic skeleton, "
+            "but not for stable production use. "
" + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py similarity index 99% rename from pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py rename to pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py index 1a1634356..a0cf7d39e 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py @@ -9,7 +9,7 @@ import torch.nn as nn from torch.optim import Optimizer -from pytorch_forecasting.models.base.base_model_refactor import BaseModel +from pytorch_forecasting.models.base._base_model_v2 import BaseModel class TFT(BaseModel): diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/_test_tft_v2.py similarity index 99% rename from tests/test_models/test_tft_v2.py rename to tests/test_models/_test_tft_v2.py index 791ea10ef..13d92d5db 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/_test_tft_v2.py @@ -6,7 +6,7 @@ from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.timeseries import TimeSeries -from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT +from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT BATCH_SIZE_TEST = 2 MAX_ENCODER_LENGTH_TEST = 10