8564 fourier positional encoding #8570
Conversation
Walkthrough

Adds Fourier-based positional embedding support to PatchEmbeddingBlock and a new optional pos_embed_kwargs parameter forwarded to position-embedding builders. SUPPORTED_POS_EMBEDDING_TYPES now includes "fourier". PatchEmbeddingBlock uses build_fourier_position_embedding when pos_embed_type="fourier" and forwards kwargs to the sincos/fourier builders; pos_embed_kwargs defaults to {}. Implements build_fourier_position_embedding in pos_embed_utils.py, exports it via __all__, and returns a fixed non-trainable nn.Parameter of shape [1, N, embed_dim]. Tests added to validate Fourier embeddings and related invalid-argument cases.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
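For orientation, a minimal usage sketch of the API the walkthrough describes. The pos_embed_kwargs forwarding and the "scales" key are taken from this PR; the remaining constructor arguments follow the existing PatchEmbeddingBlock signature.

```python
# Usage sketch, assuming the pos_embed_kwargs forwarding described in the walkthrough.
import torch
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock

block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    num_heads=8,
    pos_embed_type="fourier",
    pos_embed_kwargs={"scales": [1.0, 1.0, 2.0]},  # per-axis scales make the encoding anisotropic
    spatial_dims=3,
)
tokens = block(torch.randn(1, 1, 32, 32, 32))  # (1, 64, 96): (32/8)^3 patches
print(block.position_embeddings.shape)  # torch.Size([1, 64, 96]); requires_grad=False
```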
Pre-merge checks: ❌ Failed checks (1 warning) · ✅ Passed checks (4 passed)
Actionable comments posted: 4
🧹 Nitpick comments (7)
monai/networks/blocks/pos_embed_utils.py (3)
38-53: Docstring polish and type consistency.

Fix typos, align types with the signature, and document raised errors.
```diff
-    Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
+    Builds an (Anisotropic) Fourier feature-based positional embedding based on the given grid size, embed dimension,
@@
-    Args:
-        grid_size (List[int]): The size of the grid in each spatial dimension.
+    Args:
+        grid_size (Union[int, Sequence[int]]): The size of the grid in each spatial dimension.
@@
-        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
-        scales (List[float]): The scale for every spatial dimension. If a single float is provided,
-            the same scale is used for all dimensions.
+        spatial_dims (int): The number of spatial dimensions (e.g., 2 for 2D, 3 for 3D).
+        scales (Union[float, Sequence[float]]): Per-dimension scale(s). If a single float is provided,
+            the same scale is used for all dimensions.
@@
-    Returns:
-        pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+    Returns:
+        torch.nn.Parameter: The Fourier feature position embedding as a fixed parameter.
+
+    Raises:
+        AssertionError: If `embed_dim` is not divisible by 2.
+        ValueError: If `scales` length does not equal `spatial_dims`.
```
21-21: Sort __all__ (RUF022).

```diff
-__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]
+__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]
```
64-66: Optional: determinism and device/dtype alignment.

- Consider accepting generator: Optional[torch.Generator] = None and calling torch.normal(..., generator=generator) for reproducibility.
- Create positions on gaussians.device/dtype to avoid implicit casts.

monai/networks/blocks/patchembedding.py (2)
112-133: Avoid duplicating grid_size construction.

Build once and reuse for both branches.

```diff
         if self.pos_embed_type == "none":
             pass
         elif self.pos_embed_type == "learnable":
             trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
         elif self.pos_embed_type == "sincos":
-            grid_size = []
-            for in_size, pa_size in zip(img_size, patch_size):
-                grid_size.append(in_size // pa_size)
-
+            grid_size = [in_size // pa_size for in_size, pa_size in zip(img_size, patch_size)]
             self.position_embeddings = build_sincos_position_embedding(
                 grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
             )
         elif self.pos_embed_type == "fourier":
-            grid_size = []
-            for in_size, pa_size in zip(img_size, patch_size):
-                grid_size.append(in_size // pa_size)
-
+            grid_size = [in_size // pa_size for in_size, pa_size in zip(img_size, patch_size)]
             self.position_embeddings = build_fourier_position_embedding(
                 grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
             )
```
119-126: Filter kwargs per embedding type to prevent an accidental TypeError.

Guard against unexpected keys in pos_embed_kwargs (e.g., passing scales to sincos).

```diff
-            self.position_embeddings = build_sincos_position_embedding(
-                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
-            )
+            allowed = {"temperature"}
+            kw = {k: v for k, v in pos_embed_kwargs.items() if k in allowed}
+            self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims, **kw)
```

And analogously for the Fourier branch:

```diff
-            self.position_embeddings = build_fourier_position_embedding(
-                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
-            )
+            allowed = {"scales"}
+            kw = {k: v for k, v in pos_embed_kwargs.items() if k in allowed}
+            self.position_embeddings = build_fourier_position_embedding(grid_size, hidden_size, spatial_dims, **kw)
```

tests/networks/blocks/test_patchembedding.py (2)
90-102: Fourier positional embedding test: good.

Consistently checks the non-trainable param. Consider also asserting that the shape equals the expected value for completeness.
90-102: Optional: add an odd hidden_size failure test for Fourier.

After enforcing embed_dim % 2 == 0, add a negative test (e.g., hidden_size=95) to assert it raises; a sketch follows.
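A sketch of that negative test; the expected error type hedges between AssertionError and ValueError, since the review proposes both at different points.

```python
# Sketch of the suggested negative test for odd embedding dimensions.
import unittest

from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding


class TestFourierOddEmbedDim(unittest.TestCase):
    def test_odd_embed_dim_raises(self):
        # embed_dim=95 is odd, so sin/cos pairing is impossible.
        with self.assertRaises((AssertionError, ValueError)):
            build_fourier_position_embedding([4, 4, 4], 95, spatial_dims=3)


if __name__ == "__main__":
    unittest.main()
```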
📒 Files selected for processing (3)

- monai/networks/blocks/patchembedding.py (5 hunks)
- monai/networks/blocks/pos_embed_utils.py (2 hunks)
- tests/networks/blocks/test_patchembedding.py (2 hunks)
🪛 Ruff (0.12.2)
monai/networks/blocks/pos_embed_utils.py
- 21-21: __all__ is not sorted; apply an isort-style sorting to __all__ (RUF022)
- 60-60: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (2)
monai/networks/blocks/patchembedding.py (2)
23-31: Fourier import and type registration: looks good.

128-133: No in-repo PatchEmbeddingBlock(pos_embed_type='fourier') calls found; confirm external/positional callers won't pass an odd hidden_size.
Actionable comments posted: 1
♻️ Duplicate comments (2)
monai/networks/blocks/pos_embed_utils.py (2)
63-68: Fix scales validation to catch wrong-length lists and normalize dtype.

The current check misses 1D lists of wrong length and will fail later with a RuntimeError during broadcast.

Apply:

```diff
-    scales: torch.Tensor = torch.as_tensor(scales, dtype=torch.float)
-    if scales.ndim > 1 and scales.ndim != spatial_dims:
-        raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")
-    if scales.ndim == 0:
-        scales = scales.repeat(spatial_dims)
+    scales: torch.Tensor = torch.as_tensor(scales, dtype=torch.get_default_dtype())
+    if scales.ndim == 0:
+        scales = scales.repeat(spatial_dims)
+    elif scales.ndim == 1 and scales.numel() == spatial_dims:
+        pass
+    else:
+        raise ValueError(f"scales must be a float or 1D sequence of length spatial_dims={spatial_dims}")
```
58-61: Relax divisibility: require even embed_dim, not multiples of 2*spatial_dims.

Fourier features only need sin/cos pairing; enforcing divisibility by 2*spatial_dims is unnecessarily restrictive and will reject valid configs (e.g., 2D with embed_dim=130).

Apply:

```diff
-    if embed_dim % (2 * spatial_dims) != 0:
-        raise AssertionError(
-            f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
-        )
+    if embed_dim % 2 != 0:
+        raise AssertionError("Embed dimension must be even for Fourier position embedding")
```
🧹 Nitpick comments (4)
monai/networks/blocks/pos_embed_utils.py (3)
69-75: Align dtype/device for numerics.

Make gaussians/positions match scales' dtype/device to avoid accidental dtype/device mismatches.

Apply:

```diff
-    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
+    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims), dtype=scales.dtype, device=scales.device)
     gaussians = gaussians * scales
-    positions = [torch.linspace(0, 1, x) for x in grid_size]
+    positions = [torch.linspace(0, 1, x, dtype=scales.dtype, device=scales.device) for x in grid_size]
     positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), dim=-1)
     positions = positions.flatten(end_dim=-2)
```
38-53: Docstring nits: fix typo and document Raises.

Tighten wording and add a brief Raises section for the new validations.

Apply:

```diff
-    Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
+    Builds an (anisotropic) Fourier-feature positional encoding based on the given grid size, embed dimension,
@@
     Returns:
         pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+
+    Raises:
+        AssertionError: if embed_dim is not even.
+        ValueError: if grid_size length != spatial_dims or scales has invalid shape/length.
```
21-21: Sort __all__ for consistency.

Minor style cleanup.

Apply:

```diff
-__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]
+__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]
```

tests/networks/blocks/test_patchembedding.py (1)
90-102: Also assert the shape of the Fourier positional embedding.

Strengthen the test by validating N and hidden_size.

Apply:

```diff
     def test_fourier_pos_embed(self):
         net = PatchEmbeddingBlock(
             @@
             dropout_rate=0.5,
         )
-        self.assertEqual(net.position_embeddings.requires_grad, False)
+        self.assertEqual(net.position_embeddings.requires_grad, False)
+        # (32/8)^3 = 64 tokens, hidden_size=96
+        self.assertEqual(tuple(net.position_embeddings.shape), (1, 64, 96))
```
📒 Files selected for processing (2)

- monai/networks/blocks/pos_embed_utils.py (2 hunks)
- tests/networks/blocks/test_patchembedding.py (2 hunks)
🪛 Ruff (0.12.2)
monai/networks/blocks/pos_embed_utils.py
- 21-21: __all__ is not sorted; apply an isort-style sorting to __all__ (RUF022)
- 59-61: Avoid specifying long messages outside the exception class (TRY003)
- 65-65: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (2)
monai/networks/blocks/pos_embed_utils.py (1)
72-79: Good: correct use of the dim kwarg (not axis) and flattening before matmul.

This prevents runtime errors seen with NumPy-style keywords.
tests/networks/blocks/test_patchembedding.py (1)
117-141: LGTM: split assertRaises and use a valid dropout_rate; verify the tests locally.
Sandbox couldn't run pytest (pytest: command not found). Run locally:
pytest -q tests/networks/blocks/test_patchembedding.py::TestPatchEmbeddingBlock::test_fourier_pos_embed tests/networks/blocks/test_patchembedding.py::TestPatchEmbeddingBlock::test_ill_arg
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/networks/blocks/pos_embed_utils.py (1)
60-64: Over-restrictive embed_dim check; require "even" only and raise ValueError.

The feature count only requires sin/cos pairing. The current check wrongly rejects valid configs (e.g., embed_dim=770 in 3D).

```diff
-    if embed_dim % (2 * spatial_dims) != 0:
-        raise AssertionError(
-            f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
-        )
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be even for Fourier position embedding")
```
🧹 Nitpick comments (5)
monai/networks/blocks/pos_embed_utils.py (5)
21-21: Sort __all__ to satisfy Ruff (RUF022).

Simple style fix.

```diff
-__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]
+__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]
```
39-53: Docstring typos/clarity.

Fix "Anistropic" and tighten the parameter docs.

```diff
-    Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
-    spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant
-    points more distinguishable.
-    Reference: https://arxiv.org/abs/2509.02488
+    Builds an anisotropic Fourier-feature positional encoding for the given grid size, embedding dimension,
+    number of spatial dimensions, and scales. Larger scales increase frequency content, making distant points
+    more distinguishable.
+    Reference: https://arxiv.org/abs/2509.02488
@@
-        grid_size (List[int]): The size of the grid in each spatial dimension.
+        grid_size (int | List[int]): The grid size per spatial dimension.
@@
-        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
-        scales (List[float]): The scale for every spatial dimension. If a single float is provided,
-            the same scale is used for all dimensions.
+        spatial_dims (int): Number of spatial dimensions (2 for 2D, 3 for 3D).
+        scales (float | List[float]): Per-dimension scale(s). A scalar applies to all dimensions.
```
65-74: Broaden scales parsing; unify dtype; simplify error text (TRY003).

Accept float/sequence/tensor; enforce 0-D, or 1-D with matching length.

```diff
-    # Ensure scales is a tensor of shape (spatial_dims,)
-    if isinstance(scales, float):
-        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
-    elif isinstance(scales, (list, tuple)):
-        if len(scales) != spatial_dims:
-            raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
-        scales_tensor = torch.tensor(scales, dtype=torch.float)
-    else:
-        raise TypeError(f"scales must be float or list of floats, got {type(scales)}")
+    # Normalize scales to tensor of shape (spatial_dims,)
+    scales_tensor = torch.as_tensor(scales, dtype=torch.float32)
+    if scales_tensor.ndim == 0:
+        scales_tensor = scales_tensor.repeat(spatial_dims)
+    elif scales_tensor.ndim == 1 and scales_tensor.numel() == spatial_dims:
+        pass
+    else:
+        raise ValueError(f"scales must be a float or 1D sequence of length spatial_dims={spatial_dims}")
```
75-77: Ensure dtype consistency and simplify init.

Use randn with an explicit dtype; fuse the multiply.

```diff
-    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
-    gaussians = gaussians * scales_tensor
+    gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor
```
78-81: Typo: "indeces" → "indices"; set dtype explicitly for coords.

Minor polish and consistent dtype.

```diff
-    position_indeces = [torch.linspace(0, 1, x) for x in grid_size_t]
-    positions = torch.stack(torch.meshgrid(*position_indeces, indexing="ij"), dim=-1)
+    position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t]
+    positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1)
```
📒 Files selected for processing (1)

- monai/networks/blocks/pos_embed_utils.py (2 hunks)
🪛 Ruff (0.12.2)
monai/networks/blocks/pos_embed_utils.py
- 21-21: __all__ is not sorted; apply an isort-style sorting to __all__ (RUF022)
- 58-58: Avoid specifying long messages outside the exception class (TRY003)
- 61-63: Avoid specifying long messages outside the exception class (TRY003)
- 70-70: Avoid specifying long messages outside the exception class (TRY003)
- 73-73: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (5)
monai/networks/blocks/pos_embed_utils.py (5)
35-38: Signature + return type: LGTM.

Matches the existing style of build_sincos_position_embedding.

55-59: Good: explicit grid_size length validation.

Prevents late matmul shape errors.

82-82: Projection math: LGTM.

Shapes: (N, spatial_dims) @ (spatial_dims, embed_dim/2) → (N, embed_dim/2).

84-87: Output shape and freezing: LGTM.

Returns [1, N, embed_dim] as a non-trainable Parameter.

39-64: Shorten error strings to appease TRY003.

The diffs above already shrink the messages; no further changes are needed if they are applied. Please run the linters to confirm TRY003 is silenced post-changes.
Also applies to: 65-74
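Taken together, the duplicate comments and nitpicks above converge on a builder roughly like the following sketch. The 2*pi factor follows the standard Fourier-feature recipe and is an assumption here, not something confirmed by the thread.

```python
# Consolidated sketch of the Fourier position-embedding builder under review.
from typing import List, Union

import torch
import torch.nn as nn


def build_fourier_position_embedding(
    grid_size: Union[int, List[int]],
    embed_dim: int,
    spatial_dims: int = 3,
    scales: Union[float, List[float]] = 1.0,
) -> nn.Parameter:
    grid = [grid_size] * spatial_dims if isinstance(grid_size, int) else list(grid_size)
    if len(grid) != spatial_dims:
        raise ValueError("grid_size length must equal spatial_dims")
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")
    # Normalize scales to shape (spatial_dims,).
    scales_t = torch.as_tensor(scales, dtype=torch.get_default_dtype())
    if scales_t.ndim == 0:
        scales_t = scales_t.repeat(spatial_dims)
    elif not (scales_t.ndim == 1 and scales_t.numel() == spatial_dims):
        raise ValueError(f"scales must be a scalar or 1D sequence of length spatial_dims={spatial_dims}")
    # Random Gaussian projection, scaled per axis (the source of anisotropy).
    gaussians = torch.randn(embed_dim // 2, spatial_dims) * scales_t
    # Normalized grid coordinates in [0, 1], one row per token: (N, spatial_dims).
    axes = [torch.linspace(0, 1, g) for g in grid]
    positions = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).flatten(end_dim=-2)
    # (N, S) @ (S, D/2) -> (N, D/2), then sin/cos pairing -> (N, D).
    x_proj = 2 * torch.pi * positions @ gaussians.t()
    pos_embed = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
    return nn.Parameter(pos_embed.unsqueeze(0), requires_grad=False)  # [1, N, embed_dim]
```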
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/networks/blocks/pos_embed_utils.py (1)
64-73: Harden scales handling (accept scalars/tensors; validate length explicitly).

Use torch.as_tensor, support ints/tensors, and validate ndim/length to avoid silent broadcast bugs.

```diff
-    # Ensure scales is a tensor of shape (spatial_dims,)
-    if isinstance(scales, float):
-        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
-    elif isinstance(scales, (list, tuple)):
-        if len(scales) != spatial_dims:
-            raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
-        scales_tensor = torch.tensor(scales, dtype=torch.float)
-    else:
-        raise TypeError(f"scales must be float or list of floats, got {type(scales)}")
+    # Normalize scales to shape (spatial_dims,)
+    scales_tensor = torch.as_tensor(scales, dtype=torch.get_default_dtype())
+    if scales_tensor.ndim == 0:
+        scales_tensor = scales_tensor.repeat(spatial_dims)
+    elif not (scales_tensor.ndim == 1 and scales_tensor.numel() == spatial_dims):
+        raise ValueError(f"scales must be a scalar or 1D sequence of length spatial_dims={spatial_dims}")
```
🧹 Nitpick comments (6)
monai/networks/blocks/pos_embed_utils.py (6)
39-44: Fix typos and wording in the docstring header.

Use "anisotropic" consistently.

```diff
-    Builds a (Anistropic) Fourier feature position embedding based on the given grid size, embed dimension,
+    Builds an (anisotropic) Fourier-feature position embedding based on the given grid size, embed dimension,
@@
-    Position embedding is made anistropic by allowing setting different scales for each spatial dimension.
+    The embedding can be made anisotropic by allowing different scales per spatial dimension.
```
45-54: Add a Raises section (per guidelines) and document determinism.

The docstring should enumerate errors and note the randomness.

```diff
     Args:
         grid_size (int | List[int]): The size of the grid in each spatial dimension.
         embed_dim (int): The dimension of the embedding.
         spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
         scales (float | List[float]): The scale for every spatial dimension. If a single float is provided,
             the same scale is used for all dimensions.

     Returns:
         pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+
+    Raises:
+        ValueError: If `len(grid_size) != spatial_dims`, if `embed_dim` is odd, or if `scales` length mismatches.
+        TypeError: If `scales` is not a scalar or a 1D sequence of numbers.
+
+    Note:
+        The projection matrix is sampled from a Gaussian; set global seeds (e.g., via MONAI `set_determinism`)
+        or pass a torch.Generator (see the suggested signature change below) to make results reproducible.
```
56-60: Tighten the error message (Ruff TRY003).

Shorten the message to satisfy the linter.

```diff
-    if len(grid_size_t) != spatial_dims:
-        raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.")
+    if len(grid_size_t) != spatial_dims:
+        raise ValueError("grid_size length must equal spatial_dims")
```
61-63: Even-dimension check: OK; shorten the message (Ruff).

```diff
-    if embed_dim % 2 != 0:
-        raise ValueError("embed_dim must be even for Fourier position embedding")
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be even")
```
74-74: Optional: expose the RNG for reproducibility.

Allow a generator to control the randomness; use the default dtype for consistency.

Apply:

```diff
-    gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor
+    gaussians = torch.randn(
+        embed_dim // 2, spatial_dims, dtype=torch.get_default_dtype(), generator=generator
+    ) * scales_tensor
```

Also update the signature:

```diff
-def build_fourier_position_embedding(
-    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
-) -> torch.nn.Parameter:
+def build_fourier_position_embedding(
+    grid_size: Union[int, List[int]],
+    embed_dim: int,
+    spatial_dims: int = 3,
+    scales: Union[float, List[float]] = 1.0,
+    generator: Optional[torch.Generator] = None,
+) -> torch.nn.Parameter:
```

Add the import (outside this hunk):

```python
from typing import Optional
```
76-79: Minor: honor the global default dtype.

Keep dtype consistent with the rest of the pipeline.

```diff
-    position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t]
+    position_indices = [torch.linspace(0, 1, x, dtype=torch.get_default_dtype()) for x in grid_size_t]
```
📒 Files selected for processing (1)

- monai/networks/blocks/pos_embed_utils.py (2 hunks)
🪛 Ruff (0.12.2)
monai/networks/blocks/pos_embed_utils.py
- 59-59: Avoid specifying long messages outside the exception class (TRY003)
- 62-62: Avoid specifying long messages outside the exception class (TRY003)
- 69-69: Avoid specifying long messages outside the exception class (TRY003)
- 72-72: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (3)
monai/networks/blocks/pos_embed_utils.py (3)
21-21: Export looks good.

Publicly exposing build_fourier_position_embedding via __all__ is appropriate.

80-81: Projection math: LGTM.

Shapes check out: (N, S) @ (S, D/2) -> (N, D/2).

82-84: Parameter creation: LGTM.

The non-trainable [1, N, D] output aligns with the existing sin-cos builder.

Consider adding a unit test asserting determinism when a global seed (or generator) is set: same inputs → identical embeddings; a sketch follows.
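A sketch of that determinism test, relying only on the signature shown above and global seeding via torch.manual_seed:

```python
# Sketch of the suggested determinism test: same seed and inputs -> identical embeddings.
import unittest

import torch

from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding


class TestFourierDeterminism(unittest.TestCase):
    def test_same_seed_same_embedding(self):
        torch.manual_seed(42)
        first = build_fourier_position_embedding([4, 4, 4], 96, spatial_dims=3, scales=[1.0, 1.0, 2.0])
        torch.manual_seed(42)
        second = build_fourier_position_embedding([4, 4, 4], 96, spatial_dims=3, scales=[1.0, 1.0, 2.0])
        torch.testing.assert_close(first, second)


if __name__ == "__main__":
    unittest.main()
```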
Fixes #8564.
Description
Add Fourier feature positional encodings to PatchEmbeddingBlock. It has been shown that Fourier feature positional encodings are better suited for anisotropic images and videos: https://arxiv.org/abs/2509.02488

Types of changes

- Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
- Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
- Documentation updated, tested the make html command in the docs/ folder.