Conversation

Contributor

@NabJa NabJa commented Sep 15, 2025

Fixes #8564.

Description

Add Fourier feature positional encodings to PatchEmbeddingBlock. It has been shown that Fourier feature positional encodings are better suited for anisotropic images and videos: https://arxiv.org/abs/2509.02488
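
For context, a minimal sketch of the technique (illustrative only; the function name and the exact scaling convention are not taken from this PR's diff): each normalized grid position in [0, 1]^d is projected by a random Gaussian matrix whose rows are scaled per spatial dimension, then embedded via sin/cos pairs.

import torch

def fourier_pos_embed_sketch(grid_size, embed_dim, scales):
    # grid_size: tokens per spatial dimension, e.g. (4, 4, 8); embed_dim must be even.
    spatial_dims = len(grid_size)
    # Random Gaussian projection, scaled per spatial dimension:
    # a larger scale yields higher frequencies (finer discrimination) along that axis.
    gaussians = torch.randn(embed_dim // 2, spatial_dims) * torch.as_tensor(scales, dtype=torch.float32)
    # Normalized grid coordinates in [0, 1] for every token position.
    axes = [torch.linspace(0, 1, g) for g in grid_size]
    positions = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).flatten(end_dim=-2)  # (N, d)
    projected = positions @ gaussians.T  # (N, embed_dim // 2)
    embedding = torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)  # (N, embed_dim)
    return embedding.unsqueeze(0)  # (1, N, embed_dim)

print(fourier_pos_embed_sketch((4, 4, 8), 96, (1.0, 1.0, 2.0)).shape)  # torch.Size([1, 128, 96])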

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Contributor

coderabbitai bot commented Sep 15, 2025

Walkthrough

Adds Fourier-based positional embedding support to PatchEmbeddingBlock and a new optional pos_embed_kwargs parameter forwarded to position-embedding builders. SUPPORTED_POS_EMBEDDING_TYPES now includes "fourier". PatchEmbeddingBlock uses build_fourier_position_embedding when pos_embed_type="fourier" and forwards kwargs to the sincos/fourier builders; pos_embed_kwargs defaults to {}. Implements build_fourier_position_embedding in pos_embed_utils.py, exports it via __all__, and returns a fixed non-trainable nn.Parameter of shape [1, N, embed_dim]. Tests added to validate Fourier embeddings and related invalid-argument cases.
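
Based on this walkthrough, usage would look roughly as follows; pos_embed_type and pos_embed_kwargs come from the summary above, while the remaining constructor arguments follow the existing PatchEmbeddingBlock signature and should be checked against the merged API.

import torch
from monai.networks.blocks import PatchEmbeddingBlock

# Anisotropic volume: a larger Fourier scale is passed for the last axis
# via the new pos_embed_kwargs parameter.
block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    num_heads=8,
    pos_embed_type="fourier",
    pos_embed_kwargs={"scales": [1.0, 1.0, 2.0]},
    spatial_dims=3,
)
x = torch.randn(2, 1, 32, 32, 32)
print(block(x).shape)  # expected: (2, 64, 96) -- (32/8)^3 = 64 tokens
print(block.position_embeddings.requires_grad)  # False: fixed, non-trainable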

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes



Pre-merge checks

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 66.67%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (4 passed)
  • Title Check — ✅ Passed. The PR title "8564 fourier positional encoding" directly references the issue number and the main feature being implemented. The changes add Fourier-based positional embedding support to PatchEmbeddingBlock, which matches the title's description of Fourier positional encoding functionality. The title is concise and clearly identifies the primary change being made.
  • Linked Issues Check — ✅ Passed. The implementation successfully addresses the core requirements from issue #8564. The code adds Fourier feature positional encodings through a new build_fourier_position_embedding function that supports anisotropic scaling via the scales parameter, integrates into PatchEmbeddingBlock with the new "fourier" pos_embed_type, and includes comprehensive test coverage. The implementation provides both isotropic and anisotropic variants as requested through the configurable scales parameter.
  • Out of Scope Changes Check — ✅ Passed. All changes are directly related to implementing Fourier positional encodings as specified in issue #8564. The modifications include adding the new embedding function, integrating it into PatchEmbeddingBlock, extending supported embedding types, and adding appropriate test coverage. No unrelated or out-of-scope changes are present in the changeset.
  • Description Check — ✅ Passed. The pull request description follows the repository template: it includes "Fixes #8564", a concise Description explaining the addition of Fourier-feature positional encodings with an arXiv reference, and a populated "Types of changes" checklist, so it is sufficient for reviewers to understand intent and scope. However, there is a minor spelling error ("Anistropic"), and the claims that docs/tests were run lack attached evidence (CI logs or build output).

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (7)
monai/networks/blocks/pos_embed_utils.py (3)

38-53: Docstring polish and type consistency.

Fix typos, align types with signature, and document raised errors.

-    Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
+    Builds an (Anisotropic) Fourier feature‑based positional embedding based on the given grid size, embed dimension,
@@
-    Args:
-        grid_size (List[int]): The size of the grid in each spatial dimension.
+    Args:
+        grid_size (Union[int, Sequence[int]]): The size of the grid in each spatial dimension.
@@
-        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
-        scales (List[float]): The scale for every spatial dimension. If a single float is provided,
-                              the same scale is used for all dimensions.
+        spatial_dims (int): The number of spatial dimensions (e.g., 2 for 2D, 3 for 3D).
+        scales (Union[float, Sequence[float]]): Per‑dimension scale(s). If a single float is provided,
+            the same scale is used for all dimensions.
@@
-    Returns:
-        pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+    Returns:
+        torch.nn.Parameter: The Fourier feature position embedding as a fixed parameter.
+
+    Raises:
+        AssertionError: If `embed_dim` is not divisible by 2.
+        ValueError: If `scales` length does not equal `spatial_dims`.

21-21: Sort __all__ (RUF022).

-__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]
+__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]

64-66: Optional: determinism and device/dtype alignment.

  • Consider accepting generator: Optional[torch.Generator] = None and calling torch.normal(..., generator=generator) for reproducibility.
  • Create positions on gaussians.device/dtype to avoid implicit casts; a sketch of both suggestions follows below.
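
A minimal sketch of those two suggestions, assuming the generator parameter is added to the builder (it is not part of the current diff):

import torch
from typing import Optional

def sample_scaled_gaussians(
    embed_dim: int, spatial_dims: int, scales: torch.Tensor, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
    # Reproducible draw when a generator is supplied; dtype/device follow `scales`
    # so the subsequent multiply does not trigger implicit casts.
    gaussians = torch.randn(
        embed_dim // 2, spatial_dims, generator=generator, dtype=scales.dtype, device=scales.device
    )
    return gaussians * scales

gen = torch.Generator().manual_seed(0)
a = sample_scaled_gaussians(96, 3, torch.ones(3), generator=gen)
gen.manual_seed(0)
b = sample_scaled_gaussians(96, 3, torch.ones(3), generator=gen)
assert torch.equal(a, b)  # identical draws for the same seed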
monai/networks/blocks/patchembedding.py (2)

112-133: Avoid duplicating grid_size construction.

Build once and reuse for both branches.

-        if self.pos_embed_type == "none":
+        if self.pos_embed_type == "none":
             pass
         elif self.pos_embed_type == "learnable":
             trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
         elif self.pos_embed_type == "sincos":
-            grid_size = []
-            for in_size, pa_size in zip(img_size, patch_size):
-                grid_size.append(in_size // pa_size)
-
+            grid_size = [in_size // pa_size for in_size, pa_size in zip(img_size, patch_size)]
             self.position_embeddings = build_sincos_position_embedding(
                 grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
             )
         elif self.pos_embed_type == "fourier":
-            grid_size = []
-            for in_size, pa_size in zip(img_size, patch_size):
-                grid_size.append(in_size // pa_size)
-
+            grid_size = [in_size // pa_size for in_size, pa_size in zip(img_size, patch_size)]
             self.position_embeddings = build_fourier_position_embedding(
                 grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
             )

119-126: Filter kwargs per embedding type to prevent accidental TypeError.

Guard against unexpected keys in pos_embed_kwargs (e.g., passing scales to sincos).

-            self.position_embeddings = build_sincos_position_embedding(
-                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
-            )
+            allowed = {"temperature"}
+            kw = {k: v for k, v in pos_embed_kwargs.items() if k in allowed}
+            self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims, **kw)

And analogously for the Fourier branch:

-            self.position_embeddings = build_fourier_position_embedding(
-                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
-            )
+            allowed = {"scales"}
+            kw = {k: v for k, v in pos_embed_kwargs.items() if k in allowed}
+            self.position_embeddings = build_fourier_position_embedding(grid_size, hidden_size, spatial_dims, **kw)
tests/networks/blocks/test_patchembedding.py (2)

90-102: Fourier positional embedding test: good.

Consistently checks non‑trainable param. Consider also asserting shape equals expected for completeness.


90-102: Optional: add an odd hidden_size failure test for Fourier.

After enforcing embed_dim % 2 == 0, add a negative test (e.g., hidden_size=95) to assert it raises.
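
A possible shape for that negative test, assuming it lives in the existing test class and that the final implementation raises either AssertionError or ValueError:

def test_fourier_pos_embed_odd_hidden_size(self):
    # hidden_size=95 is odd, so sin/cos pairing is impossible and the builder should reject it.
    with self.assertRaises((AssertionError, ValueError)):
        PatchEmbeddingBlock(
            in_channels=1,
            img_size=(32, 32, 32),
            patch_size=(8, 8, 8),
            hidden_size=95,
            num_heads=5,
            pos_embed_type="fourier",
            spatial_dims=3,
        )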

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between fd13c1b and 2794dec.

📒 Files selected for processing (3)
  • monai/networks/blocks/patchembedding.py (5 hunks)
  • monai/networks/blocks/pos_embed_utils.py (2 hunks)
  • tests/networks/blocks/test_patchembedding.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of the Google style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements improving code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/networks/blocks/test_patchembedding.py
  • monai/networks/blocks/patchembedding.py
  • monai/networks/blocks/pos_embed_utils.py
🪛 Ruff (0.12.2)
monai/networks/blocks/pos_embed_utils.py

21-21: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


60-60: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (2)
monai/networks/blocks/patchembedding.py (2)

23-31: Fourier import and type registration: looks good.


128-133: No in-repo PatchEmbeddingBlock(pos_embed_type='fourier') calls found — confirm external/positional callers won't pass odd hidden_size.

DCO Remediation Commit for NabJa <[email protected]>

I, NabJa <[email protected]>, hereby add my Signed-off-by to this commit: 93c9354
I, NabJa <[email protected]>, hereby add my Signed-off-by to this commit: 2794dec

Signed-off-by: NabJa <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
monai/networks/blocks/pos_embed_utils.py (2)

63-68: Fix scales validation to catch wrong-length lists and normalize dtype.

Current check misses 1D lists of wrong length and will fail later with a RuntimeError during broadcast.

Apply:

-    scales: torch.Tensor = torch.as_tensor(scales, dtype=torch.float)
-    if scales.ndim > 1 and scales.ndim != spatial_dims:
-        raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")
-    if scales.ndim == 0:
-        scales = scales.repeat(spatial_dims)
+    scales: torch.Tensor = torch.as_tensor(scales, dtype=torch.get_default_dtype())
+    if scales.ndim == 0:
+        scales = scales.repeat(spatial_dims)
+    elif scales.ndim == 1 and scales.numel() == spatial_dims:
+        pass
+    else:
+        raise ValueError(f"scales must be a float or 1D sequence of length spatial_dims={spatial_dims}")

58-61: Relax divisibility: require even embed_dim, not multiples of 2*spatial_dims.

Fourier features only need sin/cos pairing; enforcing divisibility by 2*spatial_dims is unnecessarily restrictive and will reject valid configs (e.g., 2D with embed_dim=130).

Apply:

-    if embed_dim % (2 * spatial_dims) != 0:
-        raise AssertionError(
-            f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
-        )
+    if embed_dim % 2 != 0:
+        raise AssertionError("Embed dimension must be even for Fourier position embedding")
🧹 Nitpick comments (4)
monai/networks/blocks/pos_embed_utils.py (3)

69-75: Align dtype/device for numerics.

Make gaussians/positions match scales’ dtype/device to avoid accidental dtype/device mismatches.

Apply:

-    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
+    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims), dtype=scales.dtype, device=scales.device)
     gaussians = gaussians * scales

-    positions = [torch.linspace(0, 1, x) for x in grid_size]
+    positions = [torch.linspace(0, 1, x, dtype=scales.dtype, device=scales.device) for x in grid_size]
     positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), dim=-1)
     positions = positions.flatten(end_dim=-2)

38-53: Docstring nits: fix typo and document Raises.

Tighten wording and add a brief Raises section for the new validations.

Apply:

-    Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
+    Builds an (anisotropic) Fourier-feature positional encoding based on the given grid size, embed dimension,
@@
-    Returns:
+    Returns:
         pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+
+    Raises:
+        AssertionError: if embed_dim is not even.
+        ValueError: if grid_size length != spatial_dims or scales has invalid shape/length.

21-21: Sort __all__ for consistency.

Minor style cleanup.

Apply:

-__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]
+__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]
tests/networks/blocks/test_patchembedding.py (1)

90-102: Also assert shape of Fourier positional embedding.

Strengthen the test by validating N and hidden_size.

Apply:

     def test_fourier_pos_embed(self):
         net = PatchEmbeddingBlock(
@@
             dropout_rate=0.5,
         )
 
-        self.assertEqual(net.position_embeddings.requires_grad, False)
+        self.assertEqual(net.position_embeddings.requires_grad, False)
+        # (32/8)^3 = 64 tokens, hidden_size=96
+        self.assertEqual(tuple(net.position_embeddings.shape), (1, 64, 96))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 2794dec and 3f5b125.

📒 Files selected for processing (2)
  • monai/networks/blocks/pos_embed_utils.py (2 hunks)
  • tests/networks/blocks/test_patchembedding.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py


Files:

  • tests/networks/blocks/test_patchembedding.py
  • monai/networks/blocks/pos_embed_utils.py
🪛 Ruff (0.12.2)
monai/networks/blocks/pos_embed_utils.py

21-21: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


59-61: Avoid specifying long messages outside the exception class

(TRY003)


65-65: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (2)
monai/networks/blocks/pos_embed_utils.py (1)

72-79: Good: correct use of dim kwarg (not axis) and flattening before matmul.

This prevents runtime errors seen with NumPy-style keywords.

tests/networks/blocks/test_patchembedding.py (1)

117-141: LGTM: split assertRaises and use valid dropout_rate — verify tests locally.
Sandbox couldn't run pytest (pytest: command not found). Run locally:
pytest -q tests/networks/blocks/test_patchembedding.py::TestPatchEmbeddingBlock::test_fourier_pos_embed tests/networks/blocks/test_patchembedding.py::TestPatchEmbeddingBlock::test_ill_arg

@NabJa NabJa marked this pull request as draft September 15, 2025 11:52
I, NabJa <[email protected]>, hereby add my Signed-off-by to this commit: bfb5592
I, NabJa <[email protected]>, hereby add my Signed-off-by to this commit: 3f5b125

Signed-off-by: NabJa <[email protected]>
Signed-off-by: NabJa <[email protected]>
Signed-off-by: NabJa <[email protected]>
@NabJa NabJa marked this pull request as ready for review September 16, 2025 07:29
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
monai/networks/blocks/pos_embed_utils.py (1)

60-64: Over‑restrictive embed_dim check; require "even" only and raise ValueError.
Fourier features only require sin/cos pairing. The current check wrongly rejects valid configs (e.g., embed_dim=130 in 3D, which is even but not divisible by 6).

-    if embed_dim % (2 * spatial_dims) != 0:
-        raise AssertionError(
-            f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
-        )
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be even for Fourier position embedding")
🧹 Nitpick comments (5)
monai/networks/blocks/pos_embed_utils.py (5)

21-21: Sort __all__ to satisfy Ruff (RUF022).
Simple style fix.

-__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]
+__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]

39-53: Docstring typos/clarity.
Fix “Anistropic” and tighten parameter docs.

-    Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
-    spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant
-    points more distinguishable.
-        Reference: https://arxiv.org/abs/2509.02488
+    Builds an anisotropic Fourier‑feature positional encoding for the given grid size, embedding dimension,
+    number of spatial dimensions, and scales. Larger scales increase frequency content, making distant points
+    more distinguishable.
+    Reference: https://arxiv.org/abs/2509.02488
@@
-        grid_size (List[int]): The size of the grid in each spatial dimension.
+        grid_size (int | List[int]): The grid size per spatial dimension.
@@
-        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
-        scales (List[float]): The scale for every spatial dimension. If a single float is provided,
-                              the same scale is used for all dimensions.
+        spatial_dims (int): Number of spatial dimensions (2 for 2D, 3 for 3D).
+        scales (float | List[float]): Per‑dimension scale(s). A scalar applies to all dimensions.

65-74: Broaden scales parsing; unify dtype; simplify error text (TRY003).
Accept float/sequence/tensor, enforce 0‑D or 1‑D length match.

-    # Ensure scales is a tensor of shape (spatial_dims,)
-    if isinstance(scales, float):
-        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
-    elif isinstance(scales, (list, tuple)):
-        if len(scales) != spatial_dims:
-            raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
-        scales_tensor = torch.tensor(scales, dtype=torch.float)
-    else:
-        raise TypeError(f"scales must be float or list of floats, got {type(scales)}")
+    # Normalize scales to tensor of shape (spatial_dims,)
+    scales_tensor = torch.as_tensor(scales, dtype=torch.float32)
+    if scales_tensor.ndim == 0:
+        scales_tensor = scales_tensor.repeat(spatial_dims)
+    elif scales_tensor.ndim == 1 and scales_tensor.numel() == spatial_dims:
+        pass
+    else:
+        raise ValueError(f"scales must be a float or 1D sequence of length spatial_dims={spatial_dims}")

75-77: Ensure dtype consistency and simplify init.
Use randn with explicit dtype; fuse multiply.

-    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
-    gaussians = gaussians * scales_tensor
+    gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor

78-81: Typo: “indeces” → “indices”; set dtype explicitly for coords.
Minor polish and consistent dtype.

-    position_indeces = [torch.linspace(0, 1, x) for x in grid_size_t]
-    positions = torch.stack(torch.meshgrid(*position_indeces, indexing="ij"), dim=-1)
+    position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t]
+    positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 3f5b125 and aabf175.

📒 Files selected for processing (1)
  • monai/networks/blocks/pos_embed_utils.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py


Files:

  • monai/networks/blocks/pos_embed_utils.py
🪛 Ruff (0.12.2)
monai/networks/blocks/pos_embed_utils.py

21-21: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


58-58: Avoid specifying long messages outside the exception class

(TRY003)


61-63: Avoid specifying long messages outside the exception class

(TRY003)


70-70: Avoid specifying long messages outside the exception class

(TRY003)


73-73: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
monai/networks/blocks/pos_embed_utils.py (5)

35-38: Signature + return type: LGTM.
Matches existing style of build_sincos_position_embedding.


55-59: Good: explicit grid_size length validation.
Prevents late matmul shape errors.


82-82: Projection math: LGTM.
Shapes: (N, D) @ (D, embed_dim/2) → (N, embed_dim/2).
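
A quick standalone check of those shapes (illustrative values only):

import torch

n_tokens, spatial_dims, embed_dim = 64, 3, 96
positions = torch.rand(n_tokens, spatial_dims)          # (N, S) flattened grid coordinates
gaussians = torch.randn(embed_dim // 2, spatial_dims)   # (D/2, S) random projection
projected = positions @ gaussians.T                     # (N, S) @ (S, D/2) -> (N, D/2)
embedding = torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)
assert embedding.shape == (n_tokens, embed_dim)         # (64, 96)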


84-87: Output shape and freezing: LGTM.
Returns [1, N, embed_dim] and non‑trainable Parameter.


39-64: Shorten error strings to appease TRY003.
Above diffs already shrink messages; no further changes needed if applied.

Please run linters to confirm TRY003 is silenced post‑changes.

Also applies to: 65-74

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
monai/networks/blocks/pos_embed_utils.py (1)

64-73: Harden scales handling (accept scalars/tensors; validate length explicitly).

Use torch.as_tensor, support ints/tensors, and validate ndim/length to avoid silent broadcast bugs.

-    # Ensure scales is a tensor of shape (spatial_dims,)
-    if isinstance(scales, float):
-        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
-    elif isinstance(scales, (list, tuple)):
-        if len(scales) != spatial_dims:
-            raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
-        scales_tensor = torch.tensor(scales, dtype=torch.float)
-    else:
-        raise TypeError(f"scales must be float or list of floats, got {type(scales)}")
+    # Normalize scales to shape (spatial_dims,)
+    scales_tensor = torch.as_tensor(scales, dtype=torch.get_default_dtype())
+    if scales_tensor.ndim == 0:
+        scales_tensor = scales_tensor.repeat(spatial_dims)
+    elif not (scales_tensor.ndim == 1 and scales_tensor.numel() == spatial_dims):
+        raise ValueError(f"scales must be a scalar or 1D sequence of length spatial_dims={spatial_dims}")
🧹 Nitpick comments (6)
monai/networks/blocks/pos_embed_utils.py (6)

39-44: Fix typos and wording in docstring header.

Use “Anisotropic” consistently.

-    Builds a (Anistropic) Fourier feature position embedding based on the given grid size, embed dimension,
+    Builds an (anisotropic) Fourier-feature position embedding based on the given grid size, embed dimension,
@@
-    Position embedding is made anistropic by allowing setting different scales for each spatial dimension.
+    The embedding can be made anisotropic by allowing different scales per spatial dimension.

45-54: Add Raises section (per guidelines) and document determinism.

Docstring should enumerate errors and note randomness.

     Args:
         grid_size (int | List[int]): The size of the grid in each spatial dimension.
         embed_dim (int): The dimension of the embedding.
         spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
         scales (float | List[float]): The scale for every spatial dimension. If a single float is provided,
                               the same scale is used for all dimensions.
 
     Returns:
         pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+
+    Raises:
+        ValueError: If `len(grid_size) != spatial_dims`, if `embed_dim` is odd, or if `scales` length mismatches.
+        TypeError: If `scales` is not a scalar or a 1D sequence of numbers.
+
+    Note:
+        The projection matrix is sampled from a Gaussian; set global seeds (e.g., via MONAI `set_determinism`)
+        or pass a torch.Generator (see suggested signature change below) to make results reproducible.

56-60: Tighten error message (ruff TRY003).

Shorten message to satisfy the linter.

-    if len(grid_size_t) != spatial_dims:
-        raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.")
+    if len(grid_size_t) != spatial_dims:
+        raise ValueError("grid_size length must equal spatial_dims")

61-63: Even-dimension check: OK; shorten message (ruff).

-    if embed_dim % 2 != 0:
-        raise ValueError("embed_dim must be even for Fourier position embedding")
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be even")

74-74: Optional: expose RNG for reproducibility.

Allow a generator to control randomness; use default dtype for consistency.

-    gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor
+    gaussians = torch.randn(
+        embed_dim // 2, spatial_dims, dtype=torch.get_default_dtype(), generator=generator
+    ) * scales_tensor

Also update the signature:

-def build_fourier_position_embedding(
-    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
-) -> torch.nn.Parameter:
+def build_fourier_position_embedding(
+    grid_size: Union[int, List[int]],
+    embed_dim: int,
+    spatial_dims: int = 3,
+    scales: Union[float, List[float]] = 1.0,
+    generator: Optional[torch.Generator] = None,
+) -> torch.nn.Parameter:

Add import (outside this hunk):

from typing import Optional

76-79: Minor: honor global default dtype.

Keep dtype consistent with rest of pipeline.

-    position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t]
+    position_indices = [torch.linspace(0, 1, x, dtype=torch.get_default_dtype()) for x in grid_size_t]
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between aabf175 and 8a375b6.

📒 Files selected for processing (1)
  • monai/networks/blocks/pos_embed_utils.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py


Files:

  • monai/networks/blocks/pos_embed_utils.py
🪛 Ruff (0.12.2)
monai/networks/blocks/pos_embed_utils.py

59-59: Avoid specifying long messages outside the exception class

(TRY003)


62-62: Avoid specifying long messages outside the exception class

(TRY003)


69-69: Avoid specifying long messages outside the exception class

(TRY003)


72-72: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
🔇 Additional comments (3)
monai/networks/blocks/pos_embed_utils.py (3)

21-21: Export looks good.

Publicly exposing build_fourier_position_embedding via __all__ is appropriate.


80-81: Projection math: LGTM.

Shapes check out: [(N, S) @ (S, D/2)] -> (N, D/2).


82-84: Parameter creation: LGTM.

Non-trainable [1, N, D] aligns with existing sin-cos builder.

Consider adding a unit test asserting determinism when a global seed (or generator) is set: same inputs → identical embeddings.
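
A possible sketch of that test, using a global seed (swap in a generator if the suggested parameter is adopted); the import path follows the file under review:

import torch
from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

def test_fourier_pos_embed_deterministic():
    torch.manual_seed(42)
    first = build_fourier_position_embedding([4, 4, 4], 96, spatial_dims=3)
    torch.manual_seed(42)
    second = build_fourier_position_embedding([4, 4, 4], 96, spatial_dims=3)
    # Same global seed and inputs should yield identical embeddings.
    assert torch.equal(first, second)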
