Skip to content

Conversation

eclipse0922
Copy link

@eclipse0922 eclipse0922 commented Sep 21, 2025

Fixes #3328 .

Description

A few sentences describing the changes proposed in this pull request.

This pull request introduces GenerateHeatmap and GenerateHeatmapd transforms for creating Gaussian heatmaps from landmark coordinates.
The input points are currently expected in ZYX order, but this can be changed to support XYZ if preferred.
The transforms support both batched (B, N, D) and non-batched (N, D) inputs.

Example notebooks are included for demonstration and will be removed before the PR is merged.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Adds a `GenerateHeatmap` transform to create gaussian response maps from landmark coordinates.
This transform is implemented for both array and dictionary-based workflows.
It enables the generation of heatmaps from landmark data, facilitating tasks
like landmark localization and visualization.
The transform supports 2D and 3D coordinates and offers options for controlling
the gaussian standard deviation, spatial shape, truncation, normalization, and data type.
Introduces a new interactive notebook demonstrating landmark to heatmap conversion using MONAI transforms.

This includes:
- A notebook with array and dictionary transform modes.
- A test suite for the `GenerateHeatmap` transform.

This enhancement enables users to visualize and interact with heatmap generation, facilitating a better understanding and application of the MONAI transforms.
Extends the `GenerateHeatmap` transform to support batched inputs,
allowing for more efficient processing of multiple landmark sets.

This change modifies the transform to handle inputs with a batch dimension (B, N, spatial_dims) in addition to single-point inputs (N, spatial_dims).
It also includes a demonstration of 3D heatmap generation using PyVista for visualization.
@eclipse0922 eclipse0922 marked this pull request as draft September 21, 2025 10:42
Streamlines the GenerateHeatmap and GenerateHeatmapd transforms for better usability and code clarity.

Specifically:
- Improves the input landmark array validation to provide a more descriptive error message.
- Removes example notebooks.

DCO Remediation Commit for sewon.jeon <[email protected]>

I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 8ef905b
I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 226bf90
I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 3097baf
I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 0072cb0

Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from 0072cb0 to 25ceb7f Compare September 21, 2025 11:41
Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from 4443705 to 9e33e7c Compare September 21, 2025 12:29
Copy link
Contributor

coderabbitai bot commented Sep 21, 2025

Warning

Rate limit exceeded

@eclipse0922 has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 18 minutes and 30 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between fd4be38 and fc28c71.

📒 Files selected for processing (1)
  • monai/transforms/post/array.py (3 hunks)

Walkthrough

Adds a new public GenerateHeatmap Transform in monai/transforms/post/array.py that produces per-landmark Gaussian heatmaps for 2D/3D (batched) inputs. Supports scalar or per-dimension sigma (with broadcasting), optional truncation (local windowed evaluation), spatial_shape provided at construction or call, optional normalization, and output dtype resolution via get_equivalent_dtype. Adds GenerateHeatmapd (aliases GenerateHeatmapD / GenerateHeatmapDict) in monai/transforms/post/dictionary.py to apply GenerateHeatmap per key, normalize key mappings, infer per-key spatial_shape (including from a reference), convert outputs to a reference dtype/device via convert_to_dst_type, and propagate MetaTensor metadata. Updates all exports and adds comprehensive unit tests for both array and dictionary variants.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The description section still includes the template placeholder text and has not been cleaned up to only contain the actual change summary, indicating the template was not fully filled out. Please remove the placeholder line under “### Description” and ensure only relevant details about the heatmap transforms are included to fully meet the template requirements.
Docstring Coverage ⚠️ Warning Docstring coverage is 12.50% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (3 passed)
Check name Status Explanation
Title Check ✅ Passed The title succinctly captures the core change by indicating the addition of heatmap generation transforms and aligns with the main feature introduced in the pull request.
Linked Issues Check ✅ Passed The new GenerateHeatmap and GenerateHeatmapd transforms fulfill all core issue requirements by providing reusable MONAI classes with customizable sigmas and spatial shapes, supporting batched and non-batched inputs, handling in_key/out_key mappings, and exposing the API for integration into pipelines as specified in issue #3328.
Out of Scope Changes Check ✅ Passed All changes align with the heatmap transform feature scope by adding the core GenerateHeatmap API, dictionary wrapper, associated tests, and export updates without introducing unrelated modifications.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922 eclipse0922 marked this pull request as ready for review September 21, 2025 14:27
@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from 60b58f5 to 54a81a5 Compare September 21, 2025 14:27
@eclipse0922
Copy link
Author

Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922
Copy link
Author

@coderabbitai review

@eclipse0922
Copy link
Author

@coderabbitai help

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (8)
monai/transforms/post/dictionary.py (2)

516-652: LGTM with minor improvements needed.

The implementation is solid. A few observations:

  1. The _update_spatial_metadata method has complex logic for determining spatial shape from different tensor dimensions. Consider simplifying or adding more explicit comments about the logic.

  2. Error messages could be extracted to constants for better maintainability (lines 573, 583, 598, 619-621, 633).

  3. The type hints could be more specific - Any is used extensively where more concrete types might be known.

Consider extracting error messages:

+_ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length."
+_ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided."
+_ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes."
+_ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
+_ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)."
+_ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."

 def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]:
     if heatmap_keys is None:
         return tuple(f"{key}_heatmap" for key in self.keys)
     keys_tuple = ensure_tuple(heatmap_keys)
     if len(keys_tuple) == 1 and len(self.keys) > 1:
         keys_tuple = keys_tuple * len(self.keys)
     if len(keys_tuple) != len(self.keys):
-        raise ValueError("heatmap_keys length must match keys length.")
+        raise ValueError(_ERR_HEATMAP_KEYS_LEN)
     return keys_tuple

636-650: Simplify spatial metadata update logic.

The _update_spatial_metadata method has nested conditionals that make it hard to follow. The logic for distinguishing batched 2D from non-batched 3D is particularly complex.

Consider a clearer approach:

 def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None:
     """Update spatial metadata of heatmap based on its dimensions."""
-    # Update spatial_shape metadata based on heatmap dimensions
-    if heatmap.ndim == 5:  # 3D batched: (B, C, H, W, D)
-        heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
-    elif heatmap.ndim == 4:  # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
-        # Need to check if this is batched 2D or non-batched 3D
-        if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])):
-            # Non-batched 3D
-            heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
-        else:
-            # Batched 2D
-            heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
-    else:  # 2D non-batched: (C, H, W)
-        heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
+    # Determine if batched based on reference's batch dimension
+    ref_is_batched = len(reference.shape) > len(reference.meta.get("spatial_shape", [])) + 1
+    
+    if heatmap.ndim == 5:  # 3D batched
+        spatial_shape = heatmap.shape[2:]
+    elif heatmap.ndim == 4:
+        # Disambiguate: 2D batched vs 3D non-batched
+        spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:]
+    else:  # ndim == 3, 2D non-batched
+        spatial_shape = heatmap.shape[1:]
+    
+    heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)
tests/transforms/test_generate_heatmapd.py (3)

101-101: Remove unused parameters.

Parameters expected_dtype and uses_ref are not used in the test method.

-def test_dict_with_reference_meta(self, _, points, params, expected_shape, expected_dtype, uses_ref):
+def test_dict_with_reference_meta(self, _, points, params, expected_shape, *_unused):

151-151: Remove unused parameter.

Parameter expected_dtype is not used.

-def test_dict_batched_with_ref(self, _, points, params, expected_shape, expected_dtype):
+def test_dict_batched_with_ref(self, _, points, params, expected_shape, _expected_dtype):

205-229: Document current behavior limitation.

The test acknowledges that MetaTensor points may inherit incorrect affine. This should be tracked as a known issue.

Should I create an issue to track the affine inheritance behavior when using MetaTensor points with reference images?

monai/transforms/post/array.py (3)

753-893: Well-implemented transform with room for minor improvements.

The GenerateHeatmap transform is well-structured. A few suggestions:

  1. The backend class attribute should be annotated with ClassVar (line 769).
  2. Consider extracting error messages to constants for maintainability.
  3. The _evaluate_gaussian method could benefit from a brief docstring.

Fix the class attribute annotation:

+from typing import ClassVar
 
 class GenerateHeatmap(Transform):
-    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
+    backend: ClassVar[list] = [TransformBackends.NUMPY, TransformBackends.TORCH]

863-865: Add boundary check optimization.

The _is_inside method could short-circuit on first failure.

 @staticmethod
 def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
-    return all(0 <= c < size for c, size in zip(center, bounds))
+    for c, size in zip(center, bounds):
+        if not (0 <= c < size):
+            return False
+    return True

881-892: Add docstring for clarity.

The _evaluate_gaussian method would benefit from documentation.

 def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
+    """Evaluate Gaussian at given coordinate shifts with specified sigmas.
+    
+    Args:
+        coord_shifts: Per-dimension coordinate offsets from center.
+        sigma: Per-dimension standard deviations.
+    
+    Returns:
+        Gaussian values at the specified coordinates.
+    """
     device = coord_shifts[0].device
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 53382d8 and 9f10dcf.

📒 Files selected for processing (4)
  • monai/transforms/post/array.py (3 hunks)
  • monai/transforms/post/dictionary.py (4 hunks)
  • tests/transforms/test_generate_heatmap.py (1 hunks)
  • tests/transforms/test_generate_heatmapd.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/transforms/post/dictionary.py
  • monai/transforms/post/array.py
  • tests/transforms/test_generate_heatmap.py
  • tests/transforms/test_generate_heatmapd.py
🪛 Ruff (0.13.1)
monai/transforms/post/dictionary.py

573-573: Avoid specifying long messages outside the exception class

(TRY003)


583-583: Avoid specifying long messages outside the exception class

(TRY003)


598-598: Avoid specifying long messages outside the exception class

(TRY003)


615-615: Avoid specifying long messages outside the exception class

(TRY003)


619-621: Avoid specifying long messages outside the exception class

(TRY003)


633-633: Avoid specifying long messages outside the exception class

(TRY003)

monai/transforms/post/array.py

769-769: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


781-781: Avoid specifying long messages outside the exception class

(TRY003)


785-785: Avoid specifying long messages outside the exception class

(TRY003)


788-788: Avoid specifying long messages outside the exception class

(TRY003)


802-804: Avoid specifying long messages outside the exception class

(TRY003)


808-808: Avoid specifying long messages outside the exception class

(TRY003)


847-847: Avoid specifying long messages outside the exception class

(TRY003)


853-853: Avoid specifying long messages outside the exception class

(TRY003)


861-861: Avoid specifying long messages outside the exception class

(TRY003)

tests/transforms/test_generate_heatmapd.py

56-56: Consider (1, *shape) instead of concatenation

Replace with (1, *shape)

(RUF005)


101-101: Unused method argument: expected_dtype

(ARG002)


101-101: Unused method argument: uses_ref

(ARG002)


151-151: Unused method argument: expected_dtype

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: packaging
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.11)
🔇 Additional comments (2)
tests/transforms/test_generate_heatmap.py (2)

24-29: LGTM!

Clean helper function for finding peak coordinates.


233-240: Good error handling test coverage.

Testing multiple invalid input scenarios is thorough.

Signed-off-by: sewon.jeon <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (12)
monai/transforms/post/array.py (5)

754-769: Clarify coordinate order and output shape in docstring.

Please document the expected coordinate order (e.g., ZYX/XY), units (voxels), and output shapes:

  • Non-batched: (N, H, W[, D])
  • Batched: (B, N, H, W[, D])
 class GenerateHeatmap(Transform):
-    """
-    Generate per-landmark gaussian response maps for 2D or 3D coordinates.
-    ...
-    """
+    """
+    Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.
+
+    Notes:
+        - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
+        - Output shape:
+            - Non-batched points (N, D): (N, H, W[, D])
+            - Batched points (B, N, D): (B, N, H, W[, D])
+        - Each channel corresponds to one landmark.
+
+    Args:
+        sigma: ...
+        ...
+
+    Raises:
+        ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved.
+    """

796-844: Avoid full-channel scans for normalization; compute peak within the updated window.

This reduces O(HWD) scans per point to O(window) while maintaining behavior (one Gaussian per channel).

-                region = heatmap[b_idx, idx][window_slices]
-                gaussian = self._evaluate_gaussian(coord_shifts, sigma)
-                torch.maximum(region, gaussian, out=region)
-                if self.normalize:
-                    max_val = heatmap[b_idx, idx].max()
-                    if max_val.item() > 0:
-                        heatmap[b_idx, idx] /= max_val
+                region = heatmap[b_idx, idx][window_slices]
+                gaussian = self._evaluate_gaussian(coord_shifts, sigma)
+                updated = torch.maximum(region, gaussian)
+                # write back
+                region.copy_(updated)
+                if self.normalize:
+                    peak = updated.max()
+                    if peak.item() > 0:
+                        heatmap[b_idx, idx] /= peak

871-907: Compute Gaussian in float32 for numerical stability, then cast to target dtype.

Half precision can underflow/overflow for the exponent. Do the math in float32 and cast at the end.

-            coord_shifts.append(torch.arange(start, stop, device=device, dtype=self.torch_dtype) - float(c))
+            coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c))
-    def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
+    def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
         """
         Evaluate Gaussian at given coordinate shifts with specified sigmas.
         ...
         """
-        device = coord_shifts[0].device
-        shape = tuple(len(axis) for axis in coord_shifts)
-        if 0 in shape:
-            return torch.zeros(shape, dtype=self.torch_dtype, device=device)
-        exponent = torch.zeros(shape, dtype=self.torch_dtype, device=device)
+        device = coord_shifts[0].device
+        shape = tuple(len(axis) for axis in coord_shifts)
+        if 0 in shape:
+            return torch.zeros(shape, dtype=self.torch_dtype, device=device)
+        exponent = torch.zeros(shape, dtype=torch.float32, device=device)
         for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)):
-            scaled = (shift / float(sig)) ** 2
+            shift32 = shift.to(torch.float32)
+            scaled = (shift32 / float(sig)) ** 2
             reshape_shape = [1] * len(coord_shifts)
             reshape_shape[dim] = shift.numel()
             exponent += scaled.reshape(reshape_shape)
-        return torch.exp(-0.5 * exponent)
+        gauss = torch.exp(-0.5 * exponent)
+        return gauss.to(dtype=self.torch_dtype)

845-856: Permit integer spatial_shape broadcasting in the call doc and error text.

Consider clarifying in error text that a single int is broadcast across dims (already implemented).

-                raise ValueError("spatial_shape length must match spatial dimension of the landmarks.")
+                raise ValueError(
+                    "spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)."
+                )

772-795: Ruff TRY003: shorten exception messages or centralize them.

Minor lint: several raises have relatively long string literals. You can shorten or move to constants if you care about TRY003.

tests/transforms/test_generate_heatmapd.py (3)

56-56: Use tuple unpacking for clarity.

Prefer (1, *shape) over (1,) + shape.

-            (1,) + shape,
+            (1, *shape),

168-171: Generalize max over spatial dims without hardcoding.

Flatten spatial dims to compute per-(B,C) maxima.

-        max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0]
-        np.testing.assert_allclose(
-            max_vals.cpu().numpy(), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5
-        )
+        hm2 = heatmap.reshape(heatmap.shape[0], heatmap.shape[1], -1)
+        max_vals = hm2.max(dim=2)[0]
+        np.testing.assert_allclose(max_vals.cpu().numpy(), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5)

205-229: Align affine to the reference image when provided.

Test name and comment imply ref affine should win. Recommend adjusting the transform to explicitly set heatmap.affine = reference.affine when a ref MetaTensor is used. See suggested change in dictionary.py.

monai/transforms/post/dictionary.py (4)

532-551: Broaden dtype parameter type to accept torch.dtype.

Matches the array transform signature and current tests.

-        dtype: np.dtype | type = np.float32,
+        dtype: np.dtype | torch.dtype | type = np.float32,

552-573: Ensure affine alignment to reference MetaTensor.

Explicitly copy the affine when reference is MetaTensor, in addition to spatial metadata.

             if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)):
                 # Convert to match reference type and device while preserving heatmap's dtype
                 heatmap, _, _ = convert_to_dst_type(
                     heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None)
                 )
                 # Copy metadata if reference is MetaTensor
                 if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):
+                    heatmap.affine = reference.affine
                     self._update_spatial_metadata(heatmap, reference)

516-523: Docstring: add output shapes and per-key behavior.

Document the default for heatmap_keys, shape inference precedence (static_shape > ref_image), and output shapes for batched/non-batched.


641-656: Simplify spatial_shape derivation.

You can reliably set spatial_shape as the trailing spatial dims regardless of batch presence:

-        ref_spatial_shape = reference.meta.get("spatial_shape", [])
-        ref_is_batched = len(reference.shape) > len(ref_spatial_shape) + 1
-
-        if heatmap.ndim == 5:  # 3D batched: (B, C, H, W, D)
-            spatial_shape = heatmap.shape[2:]
-        elif heatmap.ndim == 4:  # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
-            # Disambiguate: 2D batched vs 3D non-batched
-            spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:]
-        else:  # 2D non-batched: (C, H, W)
-            spatial_shape = heatmap.shape[1:]
+        # trailing dims after channel are spatial regardless of batch presence
+        spatial_shape = heatmap.shape[-(reference.ndim - 1):]
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 9f10dcf and aaf2833.

📒 Files selected for processing (3)
  • monai/transforms/post/array.py (4 hunks)
  • monai/transforms/post/dictionary.py (4 hunks)
  • tests/transforms/test_generate_heatmapd.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/transforms/post/array.py
  • monai/transforms/post/dictionary.py
  • tests/transforms/test_generate_heatmapd.py
🪛 Ruff (0.13.1)
monai/transforms/post/array.py

782-782: Avoid specifying long messages outside the exception class

(TRY003)


786-786: Avoid specifying long messages outside the exception class

(TRY003)


789-789: Avoid specifying long messages outside the exception class

(TRY003)


803-805: Avoid specifying long messages outside the exception class

(TRY003)


809-809: Avoid specifying long messages outside the exception class

(TRY003)


848-848: Avoid specifying long messages outside the exception class

(TRY003)


854-854: Avoid specifying long messages outside the exception class

(TRY003)


862-862: Avoid specifying long messages outside the exception class

(TRY003)

tests/transforms/test_generate_heatmapd.py

56-56: Consider (1, *shape) instead of concatenation

Replace with (1, *shape)

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: build-docs
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.6.0)
🔇 Additional comments (1)
monai/transforms/post/array.py (1)

42-49: Good use of dtype mapping for backend parity.

Using get_equivalent_dtype to consistently map numpy/torch dtypes is the right call here.

Signed-off-by: sewon.jeon <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (12)
tests/transforms/test_generate_heatmapd.py (4)

121-130: Add dtype assertions for static-shape cases.

You pass expected_dtype but only assert equality with numpy’s dtype for static cases; consider also asserting no NaN/Inf and max == 1.0 for parity with other tests.


131-136: Verify error message content.

Also assert the exception string contains a helpful hint (e.g., “Provide spatial_shape or ref_image_keys”) to lock behavior.


186-205: Add negative/mismatch cases for key lengths.

Please add tests for:

  • Mismatched heatmap_keys length → ValueError.
  • Mismatched ref_image_keys length → ValueError.
  • Per-key spatial_shape length mismatch → ValueError.

These hit the explicit validation branches.


206-231: Align test with current behavior: affine comes from reference.

The comment suggests the heatmap “may inherit from points”, but the implementation sets heatmap.affine from the reference. Prefer asserting that explicitly.

Apply this diff:

-        # Note: Currently the heatmap may inherit affine from points MetaTensor
-        # This test documents the current behavior
-        # Ideally, the heatmap should use the reference image's affine
+        # Heatmap should inherit affine from the reference image
+        assert_allclose(heatmap.affine, image.affine, type_test=False)
monai/transforms/post/dictionary.py (3)

516-527: Docstring lacks parameter/return details.

Briefly document keys, heatmap_keys/ref_image_keys broadcasting rules, dtype/device behavior, and raised exceptions to match the rest of this module.


559-581: Spatial metadata should use the resolved shape (not reference.ndim).

Relying on reference.ndim-1 to infer spatial dims can be brittle with unusual channel layouts/batch dims. You already have the resolved shape; pass it through.

Apply this diff (two locations):

@@
-                if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):
-                    heatmap.affine = reference.affine
-                    self._update_spatial_metadata(heatmap, reference)
+                if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):
+                    heatmap.affine = reference.affine
+                    self._update_spatial_metadata(heatmap, shape)

And update the helper:

@@
-    def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None:
-        """Update spatial metadata of heatmap based on its dimensions."""
-        # trailing dims after channel are spatial regardless of batch presence
-        spatial_shape = heatmap.shape[-(reference.ndim - 1) :]
-        heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)
+    def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int, ...]) -> None:
+        """Set spatial_shape explicitly from resolved shape."""
+        heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)

624-636: Error path for invalid points: message clarity.

Consider appending the actual ndim found to the error string to help users debug malformed inputs.

monai/transforms/post/array.py (5)

756-775: Clarify coordinate order and output shape semantics.

Docstring is good; suggest explicitly stating that target_shape is (Z, Y, X) for 3D and (Y, X) for 2D, and output layout is channel-first per-point heatmaps (N, …) or (B, N, …).


825-846: Avoid device sync on normalization.

peak.item() forces host sync on GPU. Keep math on-device to reduce stalls.

Apply this diff:

-                if self.normalize:
-                    peak = updated.max()
-                    if peak.item() > 0:
-                        heatmap[b_idx, idx] /= peak
+                if self.normalize:
+                    peak = updated.amax()
+                    denom = torch.where(peak > 0, peak, torch.ones_like(peak))
+                    heatmap[b_idx, idx] = heatmap[b_idx, idx] / denom

854-867: Spatial shape resolution UX.

Broadcasting a single int to D dims is handy; consider adding that behavior to the class docstring for discoverability.


896-920: Gaussian evaluation is correct; minor micro-opts possible.

Current implementation is clear; if needed, precompute reshape shapes outside the loop to shave overhead in hot paths.


787-797: Ruff TRY003: long messages in exceptions.

Linters flag long exception strings; optional to shorten or centralize messages, but not a blocker.

Also applies to: 810-817, 857-865, 873-873

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between aaf2833 and 1bf0850.

📒 Files selected for processing (3)
  • monai/transforms/post/array.py (4 hunks)
  • monai/transforms/post/dictionary.py (4 hunks)
  • tests/transforms/test_generate_heatmapd.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/transforms/post/dictionary.py
  • monai/transforms/post/array.py
  • tests/transforms/test_generate_heatmapd.py
🪛 Ruff (0.13.1)
monai/transforms/post/array.py

789-789: Avoid specifying long messages outside the exception class

(TRY003)


793-793: Avoid specifying long messages outside the exception class

(TRY003)


796-796: Avoid specifying long messages outside the exception class

(TRY003)


810-812: Avoid specifying long messages outside the exception class

(TRY003)


816-816: Avoid specifying long messages outside the exception class

(TRY003)


857-857: Avoid specifying long messages outside the exception class

(TRY003)


863-865: Avoid specifying long messages outside the exception class

(TRY003)


873-873: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (windows-latest)
🔇 Additional comments (6)
tests/transforms/test_generate_heatmapd.py (1)

99-120: Solid coverage for reference-driven behavior.

The assertions on MetaTensor type, shape, affine propagation, and normalization are on point.

monai/transforms/post/dictionary.py (2)

592-601: Key broadcasting logic: good and consistent.

Tuplifying and broadcasting single entries to len(keys) matches established patterns in this file.


602-623: Shape normalization handles all call forms.

Accepts None, single shape, or per-key shapes robustly; clear error on length mismatch.

monai/transforms/post/array.py (3)

787-804: Dtype mapping via get_equivalent_dtype: LGTM.

Robust handling for numpy/torch dtypes; positive checks for sigma/truncated are correct.


803-813: Input validation: good guardrails.

Batched and non-batched shapes validated with clear messaging; 2D/3D restriction enforced.


850-852: Type/device conversion: correct and consistent.

Returning to the original container type with appropriate dtype is consistent with MONAI conventions.

@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from bce6263 to 40c492d Compare September 26, 2025 15:06
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
monai/transforms/post/array.py (1)

754-922: Heatmap generation looks correct; tighten normalization and minor polish.

The algorithm and I/O handling are solid. One tweak: normalize using the channel’s amax (post-write) instead of the window’s amax to be robust and clearer.

Apply this diff:

@@
-                if self.normalize:
-                    peak = updated.amax()
-                    denom = torch.where(peak > 0, peak, torch.ones_like(peak))
-                    heatmap[b_idx, idx] = heatmap[b_idx, idx] / denom
+                if self.normalize:
+                    peak = heatmap[b_idx, idx].amax()
+                    denom = torch.where(peak > 0, peak, torch.ones_like(peak))
+                    heatmap[b_idx, idx].div_(denom)

Optional:

  • Factor exception messages into class-level constants (like in GenerateHeatmapd) to satisfy Ruff TRY003. As per coding guidelines and static analysis hints.
monai/transforms/post/dictionary.py (1)

516-679: Fix sigma doc and validate static_shape rank

  • Doc: change “sequence matching number of points” → “sequence matching the number of spatial dimensions.”
  • In _determine_shape, convert points first, compute spatial_dims, then if static_shape is set raise ValueError when len(static_shape) != spatial_dims.
@@ class GenerateHeatmapd(MapTransform):
-        sigma: standard deviation for the Gaussian kernel. Can be a single value or sequence matching number of points.
+        sigma: standard deviation for the Gaussian kernel. Can be a single value or a sequence matching the number of spatial dimensions.
@@ def _determine_shape(
-        if static_shape is not None:
-            return static_shape
+        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
+        if points_t.ndim not in (2, 3):
+            raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.")
+        spatial_dims = int(points_t.shape[-1])
+        if static_shape is not None:
+            if len(static_shape) != spatial_dims:
+                raise ValueError(f"Provided static spatial_shape has {len(static_shape)} dims; expected {spatial_dims}.")
+            return static_shape
         if ref_key is not None and ref_key in data:
             return self._shape_from_reference(data[ref_key], spatial_dims)
         raise ValueError(self._ERR_NO_SHAPE)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 1bf0850 and 40c492d.

📒 Files selected for processing (3)
  • monai/transforms/post/array.py (4 hunks)
  • monai/transforms/post/dictionary.py (4 hunks)
  • tests/transforms/test_generate_heatmapd.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/transforms/test_generate_heatmapd.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/transforms/post/array.py
  • monai/transforms/post/dictionary.py
🪛 Ruff (0.13.1)
monai/transforms/post/array.py

791-791: Avoid specifying long messages outside the exception class

(TRY003)


795-795: Avoid specifying long messages outside the exception class

(TRY003)


798-798: Avoid specifying long messages outside the exception class

(TRY003)


812-814: Avoid specifying long messages outside the exception class

(TRY003)


818-818: Avoid specifying long messages outside the exception class

(TRY003)


859-859: Avoid specifying long messages outside the exception class

(TRY003)


865-867: Avoid specifying long messages outside the exception class

(TRY003)


875-875: Avoid specifying long messages outside the exception class

(TRY003)

monai/transforms/post/dictionary.py

654-654: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-os (windows-latest)
🔇 Additional comments (7)
monai/transforms/post/array.py (3)

19-19: Good: backend typing clarified.

Using ClassVar for backend is consistent with MONAI style.


42-49: Good: dtype and tensor utils imported where needed.

Imports align with new dtype handling in GenerateHeatmap.


65-65: Public export added.

Including GenerateHeatmap in all looks correct.

monai/transforms/post/dictionary.py (4)

38-38: Import of array-level transform is correct.

GenerateHeatmap is referenced directly and used as the engine.


52-52: Type/device alignment hook imported.

convert_to_dst_type import is appropriate for ref matching.


100-103: Exports added.

GenerateHeatmapd and aliases exported as expected.


677-678: Aliases defined.

Consistent with MONAI aliasing pattern.

Enhances the GenerateHeatmap transform with better normalization,
spatial metadata handling, and comprehensive documentation.

The changes ensure correct heatmap normalization, and improve
handling of spatial metadata inheritance from reference images.
Also improves input validation and fixes shape inconsistencies.

Adds new test cases to cover edge cases and improve code reliability.
Signed-off-by: sewon.jeon <[email protected]>
Fixes an issue where heatmap normalization was using the entire heatmap instead of the local region.

Adds a check to ensure that the provided static shape matches the number of spatial dimensions.

Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from 40c492d to 1b5888b Compare September 26, 2025 15:32
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (8)
monai/transforms/post/array.py (3)

778-779: Annotate mutable class attribute backend with ClassVar to satisfy Ruff (RUF012).

Minor typing fix; consistent with linter guidance.

Apply this diff:

-    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
+    backend: ClassVar[list[TransformBackends]] = [TransformBackends.NUMPY, TransformBackends.TORCH]

And add the import (outside this hunk):

from typing import ClassVar

830-847: Avoid unnecessary CPU sync and use torch finiteness checks.

.tolist() on CUDA tensors syncs to CPU. Use torch checks first; convert after.

Apply this diff:

-                center_vals = center.tolist()
-                if not np.all(np.isfinite(center_vals)):
-                    continue
+                if not torch.isfinite(center).all():
+                    continue
+                center_vals = [float(c) for c in center.detach().cpu().tolist()]

753-776: Docstring polish and TRY003 linter noise.

Messages are clear but long raises trigger TRY003. Optional: shorten messages or centralize them as class constants.

monai/transforms/post/dictionary.py (5)

553-554: Annotate backend with ClassVar to satisfy RUF012.

Minor typing improvement.

Apply this diff:

-    backend = GenerateHeatmap.backend
+    backend: ClassVar = GenerateHeatmap.backend

And add (outside this hunk) if not present in this module:

from typing import ClassVar

583-605: MetaTensor propagation: ensure MetaTensor is kept after conversion.

convert_to_dst_type(..., dtype=heatmap.dtype, device=reference.device) may return a plain Tensor if dtype/device match but meta wasn’t tracked. Safer to explicitly wrap to MetaTensor when reference is MetaTensor.

Apply this diff:

-            if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)):
+            if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)):
                 # Convert to match reference type and device while preserving heatmap's dtype
                 heatmap, _, _ = convert_to_dst_type(
                     heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None)
                 )
                 # Copy metadata if reference is MetaTensor
-                if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):
-                    heatmap.affine = reference.affine
-                    self._update_spatial_metadata(heatmap, shape)
+                if isinstance(reference, MetaTensor):
+                    if not isinstance(heatmap, MetaTensor):
+                        heatmap = MetaTensor(heatmap, affine=reference.affine, meta=reference.meta.copy())
+                    else:
+                        heatmap.affine = reference.affine
+                    self._update_spatial_metadata(heatmap, shape)

648-664: Shorten/centralize error strings (TRY003).

Optional: reduce f-string length or reuse class-level constants to silence linter.


516-551: Docstring clarity: unify axis notation and output shape description.

Use consistent Z/Y/X naming across 2D/3D (matches array variant) to avoid confusion with H/W/D wording.


583-605: Shape source precedence: warn when both static shape and ref are provided but disagree.

Currently static shape silently wins. Consider logging a warning if both exist and shapes differ.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 40c492d and 1b5888b.

📒 Files selected for processing (3)
  • monai/transforms/post/array.py (3 hunks)
  • monai/transforms/post/dictionary.py (4 hunks)
  • tests/transforms/test_generate_heatmapd.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/transforms/test_generate_heatmapd.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/transforms/post/dictionary.py
  • monai/transforms/post/array.py
🪛 Ruff (0.13.1)
monai/transforms/post/dictionary.py

653-653: Avoid specifying long messages outside the exception class

(TRY003)


657-659: Avoid specifying long messages outside the exception class

(TRY003)

monai/transforms/post/array.py

778-778: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


790-790: Avoid specifying long messages outside the exception class

(TRY003)


794-794: Avoid specifying long messages outside the exception class

(TRY003)


797-797: Avoid specifying long messages outside the exception class

(TRY003)


811-813: Avoid specifying long messages outside the exception class

(TRY003)


817-817: Avoid specifying long messages outside the exception class

(TRY003)


858-858: Avoid specifying long messages outside the exception class

(TRY003)


864-866: Avoid specifying long messages outside the exception class

(TRY003)


874-874: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (3)
monai/transforms/post/array.py (1)

41-48: LGTM on utils import expansion.

get_equivalent_dtype is appropriate here.

monai/transforms/post/dictionary.py (2)

38-39: Importing array-level transform is correct.

Keeps wrapper aligned with core implementation.


52-53: Using convert_to_dst_type is appropriate for dtype/device alignment.

No issues.

Adds a `GenerateHeatmap` transform to generate heatmaps from point data.

This transform creates heatmaps from point data, validating that the dtype
is a floating-point type.

Signed-off-by: sewon.jeon <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (4)
monai/transforms/post/array.py (4)

803-805: Future-proof the dtype check.

Use the dtype’s is_floating_point attribute to avoid hard-coding allowed types.

Apply this diff:

-        # Validate that dtype is floating-point for meaningful Gaussian values
-        if self.torch_dtype not in (torch.float16, torch.float32, torch.float64, torch.bfloat16):
-            raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}")
+        # Validate that dtype is floating-point for meaningful Gaussian values
+        if not getattr(self.torch_dtype, "is_floating_point", False):
+            raise ValueError("dtype must be a floating-point type (e.g., torch.float32 / np.float32).")

829-851: Avoid CPU sync and NumPy in the inner loop (GPU-friendly check).

Replace tolist()/np.isfinite with torch-native checks and precompute bounds tensor once per call. Reduces host-device sync and improves performance on CUDA.

Apply this diff:

         heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device)
         image_bounds = tuple(int(s) for s in target_shape)
+        bounds_t = torch.as_tensor(image_bounds, device=device, dtype=points_t.dtype)
         for b_idx in range(batch_size):
             for idx, center in enumerate(points_t[b_idx]):
-                center_vals = center.tolist()
-                if not np.all(np.isfinite(center_vals)):
-                    continue
-                if not self._is_inside(center_vals, image_bounds):
-                    continue
-                window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device)
+                if not torch.isfinite(center).all():
+                    continue
+                if not ((center >= 0).all() and (center < bounds_t).all()):
+                    continue
+                # _make_window expects Python floats; pass only when needed
+                window_slices, coord_shifts = self._make_window(center.tolist(), radius, image_bounds, device)
                 if window_slices is None:
                     continue
                 region = heatmap[b_idx, idx][window_slices]
                 gaussian = self._evaluate_gaussian(coord_shifts, sigma)
                 updated = torch.maximum(region, gaussian)
                 # write back
                 region.copy_(updated)
                 if self.normalize:
                     peak = heatmap[b_idx, idx].amax()
                     denom = torch.where(peak > 0, peak, torch.ones_like(peak))
                     heatmap[b_idx, idx].div_(denom)

807-857: Missing method docstring for call.

Add a short docstring describing inputs/outputs and shapes to match neighboring transforms.

Apply this diff:

     def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
+        """
+        Args:
+            points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D),
+                ordered as (Y, X) for 2D or (Z, Y, X) for 3D.
+            spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses
+                the value provided at construction.
+        Returns:
+            Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark.
+        Raises:
+            ValueError: if points shape/dimension or spatial_shape is invalid.
+        """

778-778: Optional: annotate backend as ClassVar to satisfy lint.

If you follow Ruff’s RUF012, annotate backend as ClassVar. This file has other instances without it, so adopt consistently if you enable the rule globally.

Apply this diff within the class:

-    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
+    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]  # consider: backend: ClassVar[list[TransformBackends]] = [...]

And add at the top of the file if standardizing:

from typing import ClassVar

Based on learnings

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 1b5888b and fd4be38.

📒 Files selected for processing (2)
  • monai/transforms/__init__.py (2 hunks)
  • monai/transforms/post/array.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/transforms/post/array.py
  • monai/transforms/__init__.py
🪛 Ruff (0.13.1)
monai/transforms/post/array.py

778-778: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


790-790: Avoid specifying long messages outside the exception class

(TRY003)


794-794: Avoid specifying long messages outside the exception class

(TRY003)


797-797: Avoid specifying long messages outside the exception class

(TRY003)


804-804: Avoid specifying long messages outside the exception class

(TRY003)


814-816: Avoid specifying long messages outside the exception class

(TRY003)


820-820: Avoid specifying long messages outside the exception class

(TRY003)


861-861: Avoid specifying long messages outside the exception class

(TRY003)


867-869: Avoid specifying long messages outside the exception class

(TRY003)


877-877: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: packaging
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.5.1)
🔇 Additional comments (4)
monai/transforms/__init__.py (2)

296-296: Top-level export for GenerateHeatmap looks good.

Re-export added correctly; aligns with new public API.


323-326: Dict-variant exports added correctly.

GenerateHeatmapd/GenerateHeatmapD/GenerateHeatmapDict are re-exported as expected.

monai/transforms/post/array.py (2)

41-48: Import additions are appropriate.

get_equivalent_dtype and related utils are correctly imported for the new transform.


64-64: Public export updated.

Adding "GenerateHeatmap" to all is correct.

Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from 0fec3ad to fc28c71 Compare September 26, 2025 16:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Heatmap generation transforms
1 participant