diff --git a/src/timesfm/timesfm_base.py b/src/timesfm/timesfm_base.py index 73b1495..5b43e8f 100644 --- a/src/timesfm/timesfm_base.py +++ b/src/timesfm/timesfm_base.py @@ -37,45 +37,38 @@ def process_group(key, group, value_name, forecast_context_len): - group = group.tail(forecast_context_len) - return np.array(group[value_name], dtype=np.float32), key + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key def moving_average(arr, window_size): - """Calculates the moving average using NumPy's convolution function.""" - # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = (np.convolve(arr_padded, np.ones(window_size), "valid") / - window_size) - return [smoothed_arr, arr - smoothed_arr] + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = (np.convolve(arr_padded, np.ones(window_size), "valid") / + window_size) + return [smoothed_arr, arr - smoothed_arr] def freq_map(freq: str): - """Returns the frequency map for the given frequency string.""" - freq = str.upper(freq) - if freq.endswith("MS"): - return 1 - elif freq.endswith(("H", "T", "MIN", "D", "B", "U", "S")): - return 0 - elif ( - freq.endswith(("W", "M")) - or freq.startswith("W-") - or (freq.startswith("M") and len(freq) == 2) - ): - return 1 - elif ( - freq.endswith(("Y", "Q", "A")) - or freq.startswith("Y-") - or freq.startswith("Q-") - or freq.startswith("A-") - ): - return 2 - else: - raise ValueError(f"Invalid frequency: {freq}") + """Returns the frequency map for the given frequency string.""" + freq = str.upper(freq) + if freq.endswith("MS"): + return 1 + elif freq.endswith(("H", "T", "MIN", "D", "B", "U", "S")): + return 0 + elif (freq.endswith(("W", "M")) or freq.startswith("W-") + or (freq.startswith("M") and len(freq) == 2)): + return 1 + elif (freq.endswith(("Y", "Q", "A")) or freq.startswith("Y-") + or freq.startswith("Q-") or freq.startswith("A-")): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") def strip_leading_nans(arr): - """ + """ Removes contiguous NaN values from the beginning of a NumPy array. Args: @@ -86,13 +79,13 @@ def strip_leading_nans(arr): If the array is all NaNs or empty, returns an empty array. """ - isnan = np.isnan(arr) - first_valid_index = np.argmax(~isnan) - return arr[first_valid_index:] + isnan = np.isnan(arr) + first_valid_index = np.argmax(~isnan) + return arr[first_valid_index:] def linear_interpolation(arr): - """ + """ Performs linear interpolation to fill NaN values in a 1D numpy array. Args: @@ -105,45 +98,44 @@ def linear_interpolation(arr): Returns the original array if there are no NaN values. """ - nans = np.isnan(arr) - if not np.any(nans): # Check if there are any NaNs + nans = np.isnan(arr) + if not np.any(nans): # Check if there are any NaNs + return arr + + def x(z): + return z.nonzero()[0] + + nans_indices = x(nans) + non_nans_indices = x(~nans) + non_nans_values = arr[~nans] + + try: + arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values) + except ValueError: + if len(non_nans_values) > 0: + mu = np.nanmean(arr) + else: + mu = 0.0 + arr = np.where(np.isfinite(arr), arr, mu) return arr - def x(z): - return z.nonzero()[0] - - nans_indices = x(nans) - non_nans_indices = x(~nans) - non_nans_values = arr[~nans] - - try: - arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values) - except ValueError: - if len(non_nans_values) > 0: - mu = np.nanmean(arr) - else: - mu = 0.0 - arr = np.where(np.isfinite(arr), arr, mu) - return arr - # Per time series normalization: forward. def _normalize(batch): - stats = [ - (np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch - ] - new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)] - return new_batch, stats + stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) + for x in batch] + new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)] + return new_batch, stats # Per time series normalization: inverse. def _renormalize(batch, stats): - return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)] + return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)] @dataclasses.dataclass(kw_only=True) class TimesFmHparams: - """Hparams used to initialize a TimesFM model for inference. + """Hparams used to initialize a TimesFM model for inference. These are the sufficient subset of hparams to configure TimesFM inference agnostic to the checkpoint version, and are not necessarily the same as the @@ -165,24 +157,24 @@ class TimesFmHparams: quantiles: Which quantiles are output by the model. """ - context_len: int = 512 - horizon_len: int = 128 - input_patch_len: int = 32 - output_patch_len: int = 128 - num_layers: int = 20 - num_heads: int = 16 - model_dims: int = 1280 - per_core_batch_size: int = 32 - backend: Literal["cpu", "gpu", "tpu"] = "cpu" - quantiles: Sequence[float] | None = DEFAULT_QUANTILES - use_positional_embedding: bool = True - # Hparams beyond the model. - point_forecast_mode: Literal["mean", "median"] = "median" + context_len: int = 512 + horizon_len: int = 128 + input_patch_len: int = 32 + output_patch_len: int = 128 + num_layers: int = 20 + num_heads: int = 16 + model_dims: int = 1280 + per_core_batch_size: int = 32 + backend: Literal["cpu", "gpu", "tpu"] = "cpu" + quantiles: Sequence[float] | None = DEFAULT_QUANTILES + use_positional_embedding: bool = True + # Hparams beyond the model. + point_forecast_mode: Literal["mean", "median"] = "median" @dataclasses.dataclass(kw_only=True) class TimesFmCheckpoint: - """Checkpoint used to initialize a TimesFM model for inference. + """Checkpoint used to initialize a TimesFM model for inference. Attributes: version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. @@ -194,16 +186,16 @@ class TimesFmCheckpoint: step: If provided, step of the checkpoint. """ - version: str = "jax" - path: str | None = None - huggingface_repo_id: str | None = None - type: Any = None - step: int | None = None - local_dir: str | None = None + version: str = "jax" + path: str | None = None + huggingface_repo_id: str | None = None + type: Any = None + step: int | None = None + local_dir: str | None = None class TimesFmBase: - """Base TimesFM forecast API for inference. + """Base TimesFM forecast API for inference. This class is the scaffolding for calling TimesFM forecast. To properly use: 1. Create an instance with the correct hyperparameters of a TimesFM model. @@ -211,53 +203,53 @@ class TimesFmBase: 3. Call `forecast` for inference. """ - def _logging(self, s): - print(s) + def _logging(self, s): + print(s) - def __post_init__(self) -> None: - """Additional initialization for subclasses before checkpoint loading.""" - pass + def __post_init__(self) -> None: + """Additional initialization for subclasses before checkpoint loading.""" + pass - def __init__(self, hparams: TimesFmHparams, - checkpoint: TimesFmCheckpoint) -> None: - """Initializes the TimesFM forecast API. + def __init__(self, hparams: TimesFmHparams, + checkpoint: TimesFmCheckpoint) -> None: + """Initializes the TimesFM forecast API. Args: hparams: Hyperparameters of the model. checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide which TimesFM version to use. """ - self.hparams = hparams - - # Expand hparams for conciseness within the model code. - self.context_len = hparams.context_len - self.horizon_len = hparams.horizon_len - self.input_patch_len = hparams.input_patch_len - self.output_patch_len = hparams.output_patch_len - self.num_layers = hparams.num_layers - self.model_dims = hparams.model_dims - self.backend = hparams.backend - self.quantiles = hparams.quantiles - self.num_heads = hparams.num_heads - self.use_pos_emb = hparams.use_positional_embedding - - # Rewrite these values in __post_init__ for SPMD. - self.num_cores = 1 - self.per_core_batch_size = hparams.per_core_batch_size - self.global_batch_size = hparams.per_core_batch_size - - self._horizon_start = self.context_len - self.input_patch_len - self.__post_init__() - self.load_from_checkpoint(checkpoint) - - def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: - """Loads a checkpoint and compiles the decoder.""" - raise NotImplementedError("`load_from_checkpoint` is not implemented.") - - def _preprocess( - self, inputs: Sequence[np.ndarray], - freq: Sequence[int]) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: - """Formats and pads raw inputs to feed into the model. + self.hparams = hparams + + # Expand hparams for conciseness within the model code. + self.context_len = hparams.context_len + self.horizon_len = hparams.horizon_len + self.input_patch_len = hparams.input_patch_len + self.output_patch_len = hparams.output_patch_len + self.num_layers = hparams.num_layers + self.model_dims = hparams.model_dims + self.backend = hparams.backend + self.quantiles = hparams.quantiles + self.num_heads = hparams.num_heads + self.use_pos_emb = hparams.use_positional_embedding + + # Rewrite these values in __post_init__ for SPMD. + self.num_cores = 1 + self.per_core_batch_size = hparams.per_core_batch_size + self.global_batch_size = hparams.per_core_batch_size + + self._horizon_start = self.context_len - self.input_patch_len + self.__post_init__() + self.load_from_checkpoint(checkpoint) + + def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: + """Loads a checkpoint and compiles the decoder.""" + raise NotImplementedError("`load_from_checkpoint` is not implemented.") + + def _preprocess( + self, inputs: Sequence[np.ndarray], freq: Sequence[int] + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """Formats and pads raw inputs to feed into the model. This function both pads each time series to match the context length, and pads the inputs to meet the SPMD shape requirement. @@ -276,50 +268,53 @@ def _preprocess( number (a multiple of `batch_size`) of examples. """ - input_ts, input_padding, inp_freq = [], [], [] - - pmap_pad = ((len(inputs) - 1) // self.global_batch_size + - 1) * self.global_batch_size - len(inputs) - - for i, ts in enumerate(inputs): - input_len = ts.shape[0] - padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) - if input_len < self.context_len: - num_front_pad = self.context_len - input_len - ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], - axis=0) - padding = np.concatenate( - [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) - elif input_len > self.context_len: - ts = ts[-self.context_len:] - padding = padding[-(self.context_len + self.horizon_len):] - - input_ts.append(ts) - input_padding.append(padding) - inp_freq.append(freq[i]) - - # Padding the remainder batch. - for _ in range(pmap_pad): - input_ts.append(input_ts[-1]) - input_padding.append(input_padding[-1]) - inp_freq.append(inp_freq[-1]) - - return ( - np.stack(input_ts, axis=0), - np.stack(input_padding, axis=0), - np.array(inp_freq).astype(np.int32).reshape(-1, 1), - pmap_pad, - ) - - def _forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - ) -> tuple[np.ndarray, np.ndarray]: - """Forecasts on a list of time series. + input_ts, input_padding, inp_freq = [], [], [] + + pmap_pad = ((len(inputs) - 1) // self.global_batch_size + + 1) * self.global_batch_size - len(inputs) + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len, ), + dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate( + [np.zeros(shape=(num_front_pad, ), dtype=float), ts], + axis=0) + padding = np.concatenate( + [np.ones(shape=(num_front_pad, ), dtype=float), padding], + axis=0) + elif input_len > self.context_len: + ts = ts[-self.context_len:] + padding = padding[-(self.context_len + self.horizon_len):] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + # Padding the remainder batch. + for _ in range(pmap_pad): + input_ts.append(input_ts[-1]) + input_padding.append(input_padding[-1]) + inp_freq.append(inp_freq[-1]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + pmap_pad, + ) + + def _forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. Args: inputs: list of time series forecast contexts. Each context time series @@ -342,18 +337,18 @@ def _forecast( Raises: ValueError: If the checkpoint is not properly loaded. """ - raise NotImplementedError("`_forecast` is not implemented.") - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - normalize: bool = False, - ) -> tuple[np.ndarray, np.ndarray]: - """Forecasts on a list of time series. + raise NotImplementedError("`_forecast` is not implemented.") + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + normalize: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. Args: inputs: list of time series forecast contexts. Each context time series @@ -378,74 +373,76 @@ def forecast( Raises: ValueError: If the checkpoint is not properly loaded. """ - stats = None - - tmp_inputs = [] - for each_input in inputs: - arr = np.array(each_input) - if not np.isfinite(arr).all(): - arr = np.where(np.isfinite(arr), arr, np.nan) - arr = strip_leading_nans(arr) - arr = linear_interpolation(arr) - tmp_inputs.append(arr) - - inputs = tmp_inputs - if normalize: - inputs, stats = _normalize(inputs) - mean_forecast, quantile_forecast = self._forecast( - inputs, - freq, - window_size, - forecast_context_len, - return_forecast_on_context, - ) - if stats is not None: - stats = np.array(stats) - mu = stats[:, 0] - sigma = stats[:, 1] - mean_forecast = mean_forecast * sigma[:, None] + mu[:, None] - quantile_forecast = (quantile_forecast * sigma[:, None, None] + - mu[:, None, None]) - if self.hparams.point_forecast_mode == "mean": - return mean_forecast, quantile_forecast - elif self.hparams.point_forecast_mode == "median": - if self._median_index == -1: - for i, quantile in enumerate(self.quantiles): - if quantile == 0.5: - self._median_index = i - break - if self._median_index == -1: - raise ValueError("Median (0.5) is not found in the model quantiles:" - f" {self.quantiles}. Please check the hparams.") - return ( - quantile_forecast[:, :, 1 + self._median_index], - quantile_forecast, - ) - else: - raise ValueError( - "Unsupported point forecast mode:" - f" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'.") - - def forecast_with_covariates( - self, - inputs: list[Sequence[float]], - dynamic_numerical_covariates: (dict[str, Sequence[Sequence[float]]] | - None) = None, - dynamic_categorical_covariates: (dict[str, Sequence[Sequence[Category]]] | - None) = None, - static_numerical_covariates: dict[str, Sequence[float]] | None = None, - static_categorical_covariates: (dict[str, Sequence[Category]] | - None) = None, - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - xreg_mode: XRegMode = "xreg + timesfm", - normalize_xreg_target_per_input: bool = True, - ridge: float = 0.0, - max_rows_per_col: int = 0, - force_on_cpu: bool = False, - ): - """Forecasts on a list of time series with covariates. + stats = None + + tmp_inputs = [] + for each_input in inputs: + arr = np.array(each_input) + if not np.isfinite(arr).all(): + arr = np.where(np.isfinite(arr), arr, np.nan) + arr = strip_leading_nans(arr) + arr = linear_interpolation(arr) + tmp_inputs.append(arr) + + inputs = tmp_inputs + if normalize: + inputs, stats = _normalize(inputs) + mean_forecast, quantile_forecast = self._forecast( + inputs, + freq, + window_size, + forecast_context_len, + return_forecast_on_context, + ) + if stats is not None: + stats = np.array(stats) + mu = stats[:, 0] + sigma = stats[:, 1] + mean_forecast = mean_forecast * sigma[:, None] + mu[:, None] + quantile_forecast = (quantile_forecast * sigma[:, None, None] + + mu[:, None, None]) + if self.hparams.point_forecast_mode == "mean": + return mean_forecast, quantile_forecast + elif self.hparams.point_forecast_mode == "median": + if self._median_index == -1: + for i, quantile in enumerate(self.quantiles): + if quantile == 0.5: + self._median_index = i + break + if self._median_index == -1: + raise ValueError( + "Median (0.5) is not found in the model quantiles:" + f" {self.quantiles}. Please check the hparams.") + return ( + quantile_forecast[:, :, 1 + self._median_index], + quantile_forecast, + ) + else: + raise ValueError( + "Unsupported point forecast mode:" + f" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'." + ) + + def forecast_with_covariates( + self, + inputs: list[Sequence[float]], + dynamic_numerical_covariates: (dict[str, Sequence[Sequence[float]]] + | None) = None, + dynamic_categorical_covariates: ( + dict[str, Sequence[Sequence[Category]]] | None) = None, + static_numerical_covariates: dict[str, Sequence[float]] | None = None, + static_categorical_covariates: (dict[str, Sequence[Category]] + | None) = None, + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + xreg_mode: XRegMode = "xreg + timesfm", + normalize_xreg_target_per_input: bool = True, + ridge: float = 0.0, + max_rows_per_col: int = 0, + force_on_cpu: bool = False, + ): + """Forecasts on a list of time series with covariates. To optimize inference speed, avoid string valued categorical covariates. @@ -463,8 +460,8 @@ def forecast_with_covariates( we do not do decomposition. forecast_context_len: optional max context length. xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm" - fits a model on the residuals of the TimesFM forecast. "timesfm + xreg" - fits a model on the targets then forecasts on the residuals via TimesFM. + fits a model on the targets then forecasts on the residuals via TimesFM. "timesfm + xreg" + fits a model on the residuals of the TimesFM forecast. normalize_xreg_target_per_input: whether to normalize the xreg target per input in the given batch. ridge: ridge penalty for the linear model. @@ -476,184 +473,188 @@ def forecast_with_covariates( the outputs of the xreg. """ - from . import xreg_lib - - # Verify and bookkeep covariates. - if not (dynamic_numerical_covariates or dynamic_categorical_covariates or - static_numerical_covariates or static_categorical_covariates): - raise ValueError( - "At least one of dynamic_numerical_covariates," - " dynamic_categorical_covariates, static_numerical_covariates," - " static_categorical_covariates must be set.") - - # Track the lengths of (1) each input, (2) the part that can be used in the - # linear model, and (3) the horizon. - input_lens, train_lens, test_lens = [], [], [] - - for i, input_ts in enumerate(inputs): - input_len = len(input_ts) - input_lens.append(input_len) - - if xreg_mode == "timesfm + xreg": - # For fitting residuals, no TimesFM forecast on the first patch. - train_lens.append(max(0, input_len - self.input_patch_len)) - elif xreg_mode == "xreg + timesfm": - train_lens.append(input_len) - else: - raise ValueError(f"Unsupported mode: {xreg_mode}") - - if dynamic_numerical_covariates: - test_lens.append( - len(list(dynamic_numerical_covariates.values())[0][i]) - input_len) - elif dynamic_categorical_covariates: - test_lens.append( - len(list(dynamic_categorical_covariates.values())[0][i]) - - input_len) - else: - test_lens.append(self.horizon_len) - - if test_lens[-1] > self.horizon_len: - raise ValueError( - "Forecast requested longer horizon than the model definition " - f"supports: {test_lens[-1]} vs {self.horizon_len}.") - - # Prepare the covariates into train and test. - train_dynamic_numerical_covariates = collections.defaultdict(list) - test_dynamic_numerical_covariates = collections.defaultdict(list) - train_dynamic_categorical_covariates = collections.defaultdict(list) - test_dynamic_categorical_covariates = collections.defaultdict(list) - for covariates, train_covariates, test_covariates in ( - ( - dynamic_numerical_covariates, - train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates, - ), - ( - dynamic_categorical_covariates, - train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates, - ), - ): - if not covariates: - continue - for covariate_name, covariate_values in covariates.items(): - for input_len, train_len, covariate_value in zip( - input_lens, train_lens, covariate_values): - train_covariates[covariate_name].append( - covariate_value[(input_len - train_len):input_len]) - test_covariates[covariate_name].append(covariate_value[input_len:]) - - # Fit models. - if xreg_mode == "timesfm + xreg": - # Forecast via TimesFM then fit a model on the residuals. - mean_outputs, _ = self.forecast( - inputs, - freq, - window_size, - forecast_context_len, - return_forecast_on_context=True, - ) - targets = [ - (np.array(input_ts)[-train_len:] - - mean_output[(self._horizon_start - train_len):self._horizon_start]) - for input_ts, mean_output, train_len in zip(inputs, mean_outputs, - train_lens) - ] - per_instance_stats = None - if normalize_xreg_target_per_input: - targets, per_instance_stats = _normalize(targets) - xregs = xreg_lib.BatchedInContextXRegLinear( - targets=targets, - train_lens=train_lens, - test_lens=test_lens, - train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, - train_dynamic_categorical_covariates= - train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates= - test_dynamic_categorical_covariates, - static_numerical_covariates=static_numerical_covariates, - static_categorical_covariates=static_categorical_covariates, - ).fit( - ridge=ridge, - one_hot_encoder_drop=None if ridge > 0 else "first", - max_rows_per_col=max_rows_per_col, - force_on_cpu=force_on_cpu, - debug_info=False, - assert_covariates=True, - assert_covariate_shapes=True, - ) - if normalize_xreg_target_per_input: - xregs = _renormalize(xregs, per_instance_stats) - outputs = [ - (mean_output[self._horizon_start:(self._horizon_start + test_len)] + - xreg) - for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) - ] - - else: - # Fit a model on the targets then forecast on the residuals via TimesFM. - targets = [ - np.array(input_ts)[-train_len:] - for input_ts, train_len in zip(inputs, train_lens) - ] - per_instance_stats = None - if normalize_xreg_target_per_input: - targets, per_instance_stats = _normalize(targets) - xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear( - targets=targets, - train_lens=train_lens, - test_lens=test_lens, - train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, - train_dynamic_categorical_covariates= - train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates= - test_dynamic_categorical_covariates, - static_numerical_covariates=static_numerical_covariates, - static_categorical_covariates=static_categorical_covariates, - ).fit( - ridge=ridge, - one_hot_encoder_drop=None if ridge > 0 else "first", - max_rows_per_col=max_rows_per_col, - force_on_cpu=force_on_cpu, - debug_info=True, - assert_covariates=True, - assert_covariate_shapes=True, - ) - mean_outputs, _ = self.forecast( - [ - target - xreg_on_context - for target, xreg_on_context in zip(targets, xregs_on_context) - ], - freq, - window_size, - forecast_context_len, - return_forecast_on_context=True, - ) - outputs = [ - (mean_output[self._horizon_start:(self._horizon_start + test_len)] + - xreg) - for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) - ] - if normalize_xreg_target_per_input: - outputs = _renormalize(outputs, per_instance_stats) - - return outputs, xregs - - def forecast_on_df( - self, - inputs: pd.DataFrame, - freq: str, - forecast_context_len: int = 0, - value_name: str = "values", - model_name: str = "timesfm", - window_size: int | None = None, - num_jobs: int = 1, - normalize: bool = False, - verbose: bool = True, - ) -> pd.DataFrame: - """Forecasts on a list of time series. + from . import xreg_lib + + # Verify and bookkeep covariates. + if not (dynamic_numerical_covariates or dynamic_categorical_covariates + or static_numerical_covariates + or static_categorical_covariates): + raise ValueError( + "At least one of dynamic_numerical_covariates," + " dynamic_categorical_covariates, static_numerical_covariates," + " static_categorical_covariates must be set.") + + # Track the lengths of (1) each input, (2) the part that can be used in the + # linear model, and (3) the horizon. + input_lens, train_lens, test_lens = [], [], [] + + for i, input_ts in enumerate(inputs): + input_len = len(input_ts) + input_lens.append(input_len) + + if xreg_mode == "timesfm + xreg": + # For fitting residuals, no xreg for the first patch (used as TimesFM context). + train_lens.append(max(0, input_len - self.input_patch_len)) + elif xreg_mode == "xreg + timesfm": + train_lens.append(input_len) + else: + raise ValueError(f"Unsupported mode: {xreg_mode}") + + if dynamic_numerical_covariates: + test_lens.append( + len(list(dynamic_numerical_covariates.values())[0][i]) - + input_len) + elif dynamic_categorical_covariates: + test_lens.append( + len(list(dynamic_categorical_covariates.values())[0][i]) - + input_len) + else: + test_lens.append(self.horizon_len) + + if test_lens[-1] > self.horizon_len: + raise ValueError( + "Forecast requested longer horizon than the model definition " + f"supports: {test_lens[-1]} vs {self.horizon_len}.") + + # Prepare the covariates into train and test. + train_dynamic_numerical_covariates = collections.defaultdict(list) + test_dynamic_numerical_covariates = collections.defaultdict(list) + train_dynamic_categorical_covariates = collections.defaultdict(list) + test_dynamic_categorical_covariates = collections.defaultdict(list) + for covariates, train_covariates, test_covariates in ( + ( + dynamic_numerical_covariates, + train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates, + ), + ( + dynamic_categorical_covariates, + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates, + ), + ): + if not covariates: + continue + for covariate_name, covariate_values in covariates.items(): + for input_len, train_len, covariate_value in zip( + input_lens, train_lens, covariate_values): + train_covariates[covariate_name].append( + covariate_value[(input_len - train_len):input_len]) + test_covariates[covariate_name].append( + covariate_value[input_len:]) + + # Fit models. + if xreg_mode == "timesfm + xreg": + # Forecast via TimesFM then fit a model on the residuals. + mean_outputs, _ = self.forecast( + inputs, + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + targets = [(np.array(input_ts)[-train_len:] - + mean_output[(self._horizon_start - + train_len):self._horizon_start]) + for input_ts, mean_output, train_len in zip( + inputs, mean_outputs, train_lens)] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = _normalize(targets) + xregs = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates= + train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates= + test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates= + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates= + test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=False, + assert_covariates=True, + assert_covariate_shapes=True, + ) + if normalize_xreg_target_per_input: + xregs = _renormalize(xregs, per_instance_stats) + outputs = [(mean_output[self._horizon_start:(self._horizon_start + + test_len)] + xreg) + for mean_output, test_len, xreg in zip( + mean_outputs, test_lens, xregs)] + + else: + # Fit a model on the targets then forecast on the residuals via TimesFM. + targets = [ + np.array(input_ts) + for input_ts, train_len in zip(inputs, train_lens) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = _normalize(targets) + xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates= + train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates= + test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates= + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates= + test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=True, + assert_covariates=True, + assert_covariate_shapes=True, + ) + mean_outputs, _ = self.forecast( + [ + target - xreg_on_context for target, xreg_on_context in + zip(targets, xregs_on_context) + ], + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + outputs = [(mean_output[self._horizon_start:(self._horizon_start + + test_len)] + xreg) + for mean_output, test_len, xreg in zip( + mean_outputs, test_lens, xregs)] + if normalize_xreg_target_per_input: + outputs = _renormalize(outputs, per_instance_stats) + + return outputs, xregs + + def forecast_on_df( + self, + inputs: pd.DataFrame, + freq: str, + forecast_context_len: int = 0, + value_name: str = "values", + model_name: str = "timesfm", + window_size: int | None = None, + num_jobs: int = 1, + normalize: bool = False, + verbose: bool = True, + ) -> pd.DataFrame: + """Forecasts on a list of time series. Args: inputs: A pd.DataFrame of all time series. The dataframe should have a @@ -675,62 +676,63 @@ def forecast_on_df( Returns: Future forecasts dataframe. """ - if not ("unique_id" in inputs.columns and "ds" in inputs.columns and - value_name in inputs.columns): - raise ValueError( - f"DataFrame must have unique_id, ds and {value_name} columns.") - if not forecast_context_len: - forecast_context_len = self.context_len - logging.info("Preprocessing dataframe.") - df_sorted = inputs.sort_values(by=["unique_id", "ds"]) - new_inputs = [] - uids = [] - if num_jobs == 1: - if verbose: - print("Processing dataframe with single process.") - for key, group in df_sorted.groupby("unique_id"): - inp, uid = process_group( - key, - group, - value_name, - forecast_context_len, - ) - new_inputs.append(inp) - uids.append(uid) - else: - if num_jobs == -1: - num_jobs = multiprocessing.cpu_count() - if verbose: - print("Processing dataframe with multiple processes.") - with multiprocessing.Pool(processes=num_jobs) as pool: - results = pool.starmap( - process_group, - [(key, group, value_name, forecast_context_len) - for key, group in df_sorted.groupby("unique_id")], + if not ("unique_id" in inputs.columns and "ds" in inputs.columns + and value_name in inputs.columns): + raise ValueError( + f"DataFrame must have unique_id, ds and {value_name} columns.") + if not forecast_context_len: + forecast_context_len = self.context_len + logging.info("Preprocessing dataframe.") + df_sorted = inputs.sort_values(by=["unique_id", "ds"]) + new_inputs = [] + uids = [] + if num_jobs == 1: + if verbose: + print("Processing dataframe with single process.") + for key, group in df_sorted.groupby("unique_id"): + inp, uid = process_group( + key, + group, + value_name, + forecast_context_len, + ) + new_inputs.append(inp) + uids.append(uid) + else: + if num_jobs == -1: + num_jobs = multiprocessing.cpu_count() + if verbose: + print("Processing dataframe with multiple processes.") + with multiprocessing.Pool(processes=num_jobs) as pool: + results = pool.starmap( + process_group, + [(key, group, value_name, forecast_context_len) + for key, group in df_sorted.groupby("unique_id")], + ) + new_inputs, uids = zip(*results) + if verbose: + print("Finished preprocessing dataframe.") + freq_inps = [freq_map(freq)] * len(new_inputs) + _, full_forecast = self.forecast(new_inputs, + freq=freq_inps, + normalize=normalize, + window_size=window_size) + if verbose: + print("Finished forecasting.") + fcst_df = make_future_dataframe( + uids=uids, + last_times=df_sorted.groupby("unique_id")["ds"].tail(1), + h=self.horizon_len, + freq=freq, ) - new_inputs, uids = zip(*results) - if verbose: - print("Finished preprocessing dataframe.") - freq_inps = [freq_map(freq)] * len(new_inputs) - _, full_forecast = self.forecast(new_inputs, - freq=freq_inps, - normalize=normalize, - window_size=window_size) - if verbose: - print("Finished forecasting.") - fcst_df = make_future_dataframe( - uids=uids, - last_times=df_sorted.groupby("unique_id")["ds"].tail(1), - h=self.horizon_len, - freq=freq, - ) - fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, 0].reshape(-1, 1) - - for i, q in enumerate(self.quantiles): - q_col = f"{model_name}-q-{q}" - fcst_df[q_col] = full_forecast[:, 0:self.horizon_len, - 1 + i].reshape(-1, 1) - if q == 0.5: - fcst_df[model_name] = fcst_df[q_col] - logging.info("Finished creating output dataframe.") - return fcst_df + fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, + 0].reshape(-1, 1) + + for i, q in enumerate(self.quantiles): + q_col = f"{model_name}-q-{q}" + fcst_df[q_col] = full_forecast[:, 0:self.horizon_len, + 1 + i].reshape(-1, 1) + if q == 0.5: + fcst_df[model_name] = fcst_df[q_col] + logging.info("Finished creating output dataframe.") + return fcst_df