5 changes: 5 additions & 0 deletions README.md
@@ -82,3 +82,8 @@ point_forecast, quantile_forecast = model.forecast(
point_forecast.shape # (2, 12)
quantile_forecast.shape # (2, 12, 10): mean, then 10th to 90th quantiles.
```
### Training-time Patch Masking (per the TimesFM paper)

To ensure the model sees all effective context lengths during training, apply a random masking strategy to the first input patch of each time series in a batch. Let `p` be the input patch length. For each series, sample `r ∈ {0, 1, …, p−1}` and mark the first `r` positions of the first patch as masked (ignored by the model). Because masking starts at the beginning of the context window, varying `r` across examples of different lengths exposes every effective context length from 1 up to the maximum training context.

A compact PyTorch reference for sampling and applying the mask is provided in `src/timesfm/train_utils/masking.py`. In a full trainer, also propagate the mask as a padding/attention mask so the model does not attend to masked positions.
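
As a rough usage sketch (the example sizes are arbitrary; the helper names come from that module):

```python
import torch

from timesfm.train_utils.masking import sample_first_patch_mask, apply_mask_to_first_patch

# Example sizes only; real values depend on the model configuration.
batch_size, num_patches, patch_len = 32, 16, 32
x_patched = torch.randn(batch_size, num_patches, patch_len)

mask = sample_first_patch_mask(patch_len, batch_size, device=x_patched.device)  # [B, P] bool
x_masked = apply_mask_to_first_patch(x_patched, mask)  # masked positions in patch 0 zeroed
```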
35 changes: 35 additions & 0 deletions src/timesfm/train_utils/masking.py
@@ -0,0 +1,35 @@

from __future__ import annotations

import torch


def sample_first_patch_mask(patch_len: int, batch_size: int, device=None) -> torch.Tensor:
    """
    Training-time random patch masking (per the TimesFM paper):
    - For each series in the batch, sample r ~ Uniform{0, 1, ..., patch_len - 1}.
    - Set m[0:r] = True (masked) and m[r:] = False for the FIRST input patch only.

    Returns:
      mask: Bool tensor of shape [batch_size, patch_len], True where positions are masked.
    """
    if patch_len <= 0 or batch_size <= 0:
        raise ValueError("patch_len and batch_size must be positive.")
    r = torch.randint(low=0, high=patch_len, size=(batch_size,), device=device)  # [B]
    idx = torch.arange(patch_len, device=device).unsqueeze(0)  # [1, P]
    mask = idx < r.unsqueeze(1)  # [B, P], True where masked
    return mask


def apply_mask_to_first_patch(x_patched: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Zeroes masked positions in the FIRST patch. In a full trainer you would also
    carry this as an attention/padding mask so those tokens are ignored.

    Args:
      x_patched: [B, num_patches, patch_len] float tensor.
      mask:      [B, patch_len] bool tensor from sample_first_patch_mask.
    """
    if x_patched.ndim != 3:
        raise ValueError("x_patched must be [B, num_patches, patch_len].")
    B, num_patches, P = x_patched.shape
    if mask.shape != (B, P):
        raise ValueError("mask must be [B, patch_len].")
    x = x_patched.clone()
    # Zero the masked values in the first patch only; later patches are untouched.
    x[:, 0, :] = x[:, 0, :].masked_fill(mask, 0.0)
    return x
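
For the attention/padding propagation mentioned in the docstring, one possible sketch (assuming the model consumes the context as a flat sequence of positions; `expand_to_context_padding` is a hypothetical helper, not part of this PR):

```python
import torch

def expand_to_context_padding(mask: torch.Tensor, num_patches: int) -> torch.Tensor:
    # Hypothetical helper: expand a [B, patch_len] first-patch mask to a
    # [B, num_patches * patch_len] padding mask (True = position is ignored).
    # Only the first patch can contain masked positions; later patches are all valid.
    B, P = mask.shape
    padding = torch.zeros(B, num_patches * P, dtype=torch.bool, device=mask.device)
    padding[:, :P] = mask
    return padding
```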
23 changes: 23 additions & 0 deletions v1/tests/test_training_mask.py
@@ -0,0 +1,23 @@
# v1/tests/test_training_mask.py
import torch

from timesfm.train_utils.masking import sample_first_patch_mask, apply_mask_to_first_patch


def test_mask_shape_and_bounds():
    B, P = 8, 16
    m = sample_first_patch_mask(P, B, device="cpu")
    assert m.shape == (B, P)
    assert m.dtype == torch.bool
    # Each row should look like [True x r] + [False x (P - r)], i.e. no False -> True transitions.
    diffs = m[:, 1:].int() - m[:, :-1].int()
    assert not (diffs == 1).any().item()


def test_apply_mask():
    B, num_patches, P = 4, 3, 8
    x = torch.ones(B, num_patches, P)
    m = torch.zeros(B, P, dtype=torch.bool)
    m[:, :3] = True  # mask the first 3 positions of the first patch
    y = apply_mask_to_first_patch(x, m)
    assert torch.allclose(y[:, 0, :3], torch.zeros(B, 3))
    assert torch.allclose(y[:, 0, 3:], torch.ones(B, P - 3))
    # Other patches are untouched.
    assert torch.allclose(y[:, 1:, :], torch.ones(B, num_patches - 1, P))