Generate heatmap transforms #8579
base: dev
Conversation
Adds a `GenerateHeatmap` transform to create Gaussian response maps from landmark coordinates. This transform is implemented for both array and dictionary-based workflows. It enables the generation of heatmaps from landmark data, facilitating tasks like landmark localization and visualization. The transform supports 2D and 3D coordinates and offers options for controlling the Gaussian standard deviation, spatial shape, truncation, normalization, and data type.
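As a quick orientation, here is a minimal usage sketch of the array transform based on the description above; the values and exact call pattern are illustrative assumptions, not code taken from this PR:

```python
import numpy as np
from monai.transforms import GenerateHeatmap  # exported at the package level by this PR

# two 2D landmarks in (Y, X) voxel coordinates, shape (N=2, D=2)
points = np.array([[12.0, 20.0], [30.0, 14.0]], dtype=np.float32)

# one Gaussian response map per landmark, normalized so each channel peaks at 1.0
transform = GenerateHeatmap(sigma=2.0, spatial_shape=(48, 48), normalize=True)
heatmap = transform(points)
print(heatmap.shape)  # expected: (2, 48, 48) -> (N, H, W)
```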
Introduces a new interactive notebook demonstrating landmark-to-heatmap conversion using MONAI transforms. This includes: - A notebook with array and dictionary transform modes. - A test suite for the `GenerateHeatmap` transform. This enhancement lets users visualize and interact with heatmap generation, making the MONAI transforms easier to understand and apply.
Extends the `GenerateHeatmap` transform to support batched inputs, allowing for more efficient processing of multiple landmark sets. This change modifies the transform to handle inputs with a batch dimension (B, N, spatial_dims) in addition to single-point inputs (N, spatial_dims). It also includes a demonstration of 3D heatmap generation using PyVista for visualization.
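A hedged sketch of the batched form described here, reusing the constructor arguments from the sketch above and passing the spatial shape at call time (also supported by the transform's signature):

```python
import numpy as np
from monai.transforms import GenerateHeatmap

# batched landmarks: (B=2, N=2, D=3) in (Z, Y, X) voxel order
points = np.zeros((2, 2, 3), dtype=np.float32)
points[0] = [[8.0, 12.0, 20.0], [16.0, 30.0, 14.0]]
points[1] = [[4.0, 6.0, 10.0], [20.0, 25.0, 25.0]]

heatmap = GenerateHeatmap(sigma=(1.5, 2.0, 2.0))(points, spatial_shape=(32, 48, 48))
print(heatmap.shape)  # expected: (2, 2, 32, 48, 48) -> (B, N, *spatial)
```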
for more information, see https://pre-commit.ci
Streamlines the GenerateHeatmap and GenerateHeatmapd transforms for better usability and code clarity. Specifically: - Improves the input landmark array validation to provide a more descriptive error message. - Removes example notebooks. DCO Remediation Commit for sewon.jeon <[email protected]> I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 8ef905b I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 226bf90 I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 3097baf I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 0072cb0 Signed-off-by: sewon.jeon <[email protected]>
force-pushed from 0072cb0 to 25ceb7f
Signed-off-by: sewon.jeon <[email protected]>
force-pushed from 4443705 to 9e33e7c
for more information, see https://pre-commit.ci
Walkthrough
Adds a new public GenerateHeatmap Transform in monai/transforms/post/array.py that produces per-landmark Gaussian heatmaps for 2D/3D (batched) inputs. Supports scalar or per-dimension sigma (with broadcasting), optional truncation (local windowed evaluation), spatial_shape provided at construction or call, optional normalization, and output dtype resolution via get_equivalent_dtype. Adds GenerateHeatmapd (aliases GenerateHeatmapD / GenerateHeatmapDict) in monai/transforms/post/dictionary.py to apply GenerateHeatmap per key, normalize key mappings, infer per-key spatial_shape (including from a reference), convert outputs to a reference dtype/device via convert_to_dst_type, and propagate MetaTensor metadata. Updates `__all__` exports and adds comprehensive unit tests for both array and dictionary variants.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
force-pushed from 60b58f5 to 54a81a5
for more information, see https://pre-commit.ci
Signed-off-by: sewon.jeon <[email protected]>
@coderabbitai review
@coderabbitai help
Actionable comments posted: 0
🧹 Nitpick comments (8)
monai/transforms/post/dictionary.py (2)
516-652
: LGTM with minor improvements needed. The implementation is solid. A few observations:
The
_update_spatial_metadata
method has complex logic for determining spatial shape from different tensor dimensions. Consider simplifying or adding more explicit comments about the logic. Error messages could be extracted to constants for better maintainability (lines 573, 583, 598, 619-621, 633).
The type hints could be more specific -
Any
is used extensively where more concrete types might be known. Consider extracting error messages:
+_ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length." +_ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided." +_ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes." +_ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys." +_ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)." +_ERR_REF_NO_SHAPE = "Reference data must define a shape attribute." def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]: if heatmap_keys is None: return tuple(f"{key}_heatmap" for key in self.keys) keys_tuple = ensure_tuple(heatmap_keys) if len(keys_tuple) == 1 and len(self.keys) > 1: keys_tuple = keys_tuple * len(self.keys) if len(keys_tuple) != len(self.keys): - raise ValueError("heatmap_keys length must match keys length.") + raise ValueError(_ERR_HEATMAP_KEYS_LEN) return keys_tuple
636-650
: Simplify spatial metadata update logic. The
_update_spatial_metadata
method has nested conditionals that make it hard to follow. The logic for distinguishing batched 2D from non-batched 3D is particularly complex. Consider a clearer approach:
def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None: """Update spatial metadata of heatmap based on its dimensions.""" - # Update spatial_shape metadata based on heatmap dimensions - if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D) - heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) - elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D) - # Need to check if this is batched 2D or non-batched 3D - if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])): - # Non-batched 3D - heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) - else: - # Batched 2D - heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) - else: # 2D non-batched: (C, H, W) - heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) + # Determine if batched based on reference's batch dimension + ref_is_batched = len(reference.shape) > len(reference.meta.get("spatial_shape", [])) + 1 + + if heatmap.ndim == 5: # 3D batched + spatial_shape = heatmap.shape[2:] + elif heatmap.ndim == 4: + # Disambiguate: 2D batched vs 3D non-batched + spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:] + else: # ndim == 3, 2D non-batched + spatial_shape = heatmap.shape[1:] + + heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)tests/transforms/test_generate_heatmapd.py (3)
101-101
: Remove unused parameters. Parameters
expected_dtype
and `uses_ref`
are not used in the test method.
-def test_dict_with_reference_meta(self, _, points, params, expected_shape, expected_dtype, uses_ref):
+def test_dict_with_reference_meta(self, _, points, params, expected_shape, *_unused):
151-151
: Remove unused parameter. Parameter
expected_dtype
is not used.
-def test_dict_batched_with_ref(self, _, points, params, expected_shape, expected_dtype):
+def test_dict_batched_with_ref(self, _, points, params, expected_shape, _expected_dtype):
205-229
: Document current behavior limitation. The test acknowledges that MetaTensor points may inherit incorrect affine. This should be tracked as a known issue.
Should I create an issue to track the affine inheritance behavior when using MetaTensor points with reference images?
monai/transforms/post/array.py (3)
753-893
: Well-implemented transform with room for minor improvements. The
GenerateHeatmap
transform is well-structured. A few suggestions:
- The
backend
class attribute should be annotated with ClassVar (line 769).
- Consider extracting error messages to constants for maintainability.
- The
_evaluate_gaussian
method could benefit from a brief docstring. Fix the class attribute annotation:
+from typing import ClassVar

 class GenerateHeatmap(Transform):
-    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
+    backend: ClassVar[list] = [TransformBackends.NUMPY, TransformBackends.TORCH]
863-865
: Add boundary check optimization. The
_is_inside
method could short-circuit on first failure.
 @staticmethod
 def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
-    return all(0 <= c < size for c, size in zip(center, bounds))
+    for c, size in zip(center, bounds):
+        if not (0 <= c < size):
+            return False
+    return True
881-892
: Add docstring for clarity. The
_evaluate_gaussian
method would benefit from documentation.
 def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
+    """Evaluate Gaussian at given coordinate shifts with specified sigmas.
+
+    Args:
+        coord_shifts: Per-dimension coordinate offsets from center.
+        sigma: Per-dimension standard deviations.
+
+    Returns:
+        Gaussian values at the specified coordinates.
+    """
     device = coord_shifts[0].device
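For reference, the separable Gaussian this helper evaluates can be written as follows (reconstructed from the docstring above and the implementation shown later in this review, with c the landmark centre and per-dimension standard deviations sigma_d):

$$
G(\mathbf{x}) = \exp\!\left(-\tfrac{1}{2}\sum_{d}\left(\frac{x_d - c_d}{\sigma_d}\right)^{2}\right)
$$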
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base
setting
📒 Files selected for processing (4)
monai/transforms/post/array.py
(3 hunks)monai/transforms/post/dictionary.py
(4 hunks)tests/transforms/test_generate_heatmap.py
(1 hunks)tests/transforms/test_generate_heatmapd.py
(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/transforms/post/dictionary.py
monai/transforms/post/array.py
tests/transforms/test_generate_heatmap.py
tests/transforms/test_generate_heatmapd.py
🪛 Ruff (0.13.1)
monai/transforms/post/dictionary.py
573-573: Avoid specifying long messages outside the exception class
(TRY003)
583-583: Avoid specifying long messages outside the exception class
(TRY003)
598-598: Avoid specifying long messages outside the exception class
(TRY003)
615-615: Avoid specifying long messages outside the exception class
(TRY003)
619-621: Avoid specifying long messages outside the exception class
(TRY003)
633-633: Avoid specifying long messages outside the exception class
(TRY003)
monai/transforms/post/array.py
769-769: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
781-781: Avoid specifying long messages outside the exception class
(TRY003)
785-785: Avoid specifying long messages outside the exception class
(TRY003)
788-788: Avoid specifying long messages outside the exception class
(TRY003)
802-804: Avoid specifying long messages outside the exception class
(TRY003)
808-808: Avoid specifying long messages outside the exception class
(TRY003)
847-847: Avoid specifying long messages outside the exception class
(TRY003)
853-853: Avoid specifying long messages outside the exception class
(TRY003)
861-861: Avoid specifying long messages outside the exception class
(TRY003)
tests/transforms/test_generate_heatmapd.py
56-56: Consider (1, *shape)
instead of concatenation
Replace with (1, *shape)
(RUF005)
101-101: Unused method argument: expected_dtype
(ARG002)
101-101: Unused method argument: uses_ref
(ARG002)
151-151: Unused method argument: expected_dtype
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: packaging
- GitHub Check: build-docs
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.11)
🔇 Additional comments (2)
tests/transforms/test_generate_heatmap.py (2)
24-29
: LGTM! Clean helper function for finding peak coordinates.
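The helper itself is not quoted in this review; a typical peak-coordinate helper of this kind might look like the sketch below (an assumption about its shape, not the PR's code):

```python
import numpy as np

def find_peak_coords(heatmap: np.ndarray) -> tuple[int, ...]:
    """Return the per-dimension index of the maximum value in a single-channel heatmap."""
    flat_idx = int(np.argmax(heatmap))
    return tuple(int(i) for i in np.unravel_index(flat_idx, heatmap.shape))
```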
233-240
: Good error handling test coverage. Testing multiple invalid input scenarios is thorough.
Signed-off-by: sewon.jeon <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (12)
monai/transforms/post/array.py (5)
754-769
: Clarify coordinate order and output shape in docstring. Please document the expected coordinate order (e.g., ZYX/XY), units (voxels), and output shapes:
- Non-batched: (N, H, W[, D])
- Batched: (B, N, H, W[, D])
class GenerateHeatmap(Transform): - """ - Generate per-landmark gaussian response maps for 2D or 3D coordinates. - ... - """ + """ + Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates. + + Notes: + - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D. + - Output shape: + - Non-batched points (N, D): (N, H, W[, D]) + - Batched points (B, N, D): (B, N, H, W[, D]) + - Each channel corresponds to one landmark. + + Args: + sigma: ... + ... + + Raises: + ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved. + """
796-844
: Avoid full-channel scans for normalization; compute peak within the updated window. This reduces O(HWD) scans per point to O(window) while maintaining behavior (one Gaussian per channel).
- region = heatmap[b_idx, idx][window_slices] - gaussian = self._evaluate_gaussian(coord_shifts, sigma) - torch.maximum(region, gaussian, out=region) - if self.normalize: - max_val = heatmap[b_idx, idx].max() - if max_val.item() > 0: - heatmap[b_idx, idx] /= max_val + region = heatmap[b_idx, idx][window_slices] + gaussian = self._evaluate_gaussian(coord_shifts, sigma) + updated = torch.maximum(region, gaussian) + # write back + region.copy_(updated) + if self.normalize: + peak = updated.max() + if peak.item() > 0: + heatmap[b_idx, idx] /= peak
871-907
: Compute Gaussian in float32 for numerical stability, then cast to target dtype. Half precision can underflow/overflow for the exponent. Do the math in float32 and cast at the end.
- coord_shifts.append(torch.arange(start, stop, device=device, dtype=self.torch_dtype) - float(c)) + coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c))- def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor: + def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor: """ Evaluate Gaussian at given coordinate shifts with specified sigmas. ... """ - device = coord_shifts[0].device - shape = tuple(len(axis) for axis in coord_shifts) - if 0 in shape: - return torch.zeros(shape, dtype=self.torch_dtype, device=device) - exponent = torch.zeros(shape, dtype=self.torch_dtype, device=device) + device = coord_shifts[0].device + shape = tuple(len(axis) for axis in coord_shifts) + if 0 in shape: + return torch.zeros(shape, dtype=self.torch_dtype, device=device) + exponent = torch.zeros(shape, dtype=torch.float32, device=device) for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)): - scaled = (shift / float(sig)) ** 2 + shift32 = shift.to(torch.float32) + scaled = (shift32 / float(sig)) ** 2 reshape_shape = [1] * len(coord_shifts) reshape_shape[dim] = shift.numel() exponent += scaled.reshape(reshape_shape) - return torch.exp(-0.5 * exponent) + gauss = torch.exp(-0.5 * exponent) + return gauss.to(dtype=self.torch_dtype)
845-856
: Permit integer spatial_shape broadcasting in the call doc and error text. Consider clarifying in error text that a single int is broadcast across dims (already implemented).
- raise ValueError("spatial_shape length must match spatial dimension of the landmarks.") + raise ValueError( + "spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)." + )
772-795
: Ruff TRY003: shorten exception messages or centralize them. Minor lint: several raises have relatively long string literals. You can shorten or move to constants if you care about TRY003.
tests/transforms/test_generate_heatmapd.py (3)
56-56
: Use tuple unpacking for clarity. Prefer (1, *shape) over (1,) + shape.
-    (1,) + shape,
+    (1, *shape),
168-171
: Generalize max over spatial dims without hardcoding. Flatten spatial dims to compute per-(B,C) maxima.
- max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0] - np.testing.assert_allclose( - max_vals.cpu().numpy(), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5 - ) + hm2 = heatmap.reshape(heatmap.shape[0], heatmap.shape[1], -1) + max_vals = hm2.max(dim=2)[0] + np.testing.assert_allclose(max_vals.cpu().numpy(), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5)
205-229
: Align affine to the reference image when provided. Test name and comment imply ref affine should win. Recommend adjusting the transform to explicitly set heatmap.affine = reference.affine when a ref MetaTensor is used. See suggested change in dictionary.py.
monai/transforms/post/dictionary.py (4)
532-551
: Broaden dtype parameter type to accept torch.dtype. Matches the array transform signature and current tests.
-    dtype: np.dtype | type = np.float32,
+    dtype: np.dtype | torch.dtype | type = np.float32,
552-573
: Ensure affine alignment to reference MetaTensor. Explicitly copy the affine when reference is MetaTensor, in addition to spatial metadata.
if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)): # Convert to match reference type and device while preserving heatmap's dtype heatmap, _, _ = convert_to_dst_type( heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None) ) # Copy metadata if reference is MetaTensor if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): + heatmap.affine = reference.affine self._update_spatial_metadata(heatmap, reference)
516-523
: Docstring: add output shapes and per-key behavior. Document the default for heatmap_keys, shape inference precedence (static_shape > ref_image), and output shapes for batched/non-batched.
641-656
: Simplify spatial_shape derivation. You can reliably set spatial_shape as the trailing spatial dims regardless of batch presence:
- ref_spatial_shape = reference.meta.get("spatial_shape", []) - ref_is_batched = len(reference.shape) > len(ref_spatial_shape) + 1 - - if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D) - spatial_shape = heatmap.shape[2:] - elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D) - # Disambiguate: 2D batched vs 3D non-batched - spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:] - else: # 2D non-batched: (C, H, W) - spatial_shape = heatmap.shape[1:] + # trailing dims after channel are spatial regardless of batch presence + spatial_shape = heatmap.shape[-(reference.ndim - 1):]
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base
setting
📒 Files selected for processing (3)
monai/transforms/post/array.py
(4 hunks)monai/transforms/post/dictionary.py
(4 hunks)tests/transforms/test_generate_heatmapd.py
(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/transforms/post/array.py
monai/transforms/post/dictionary.py
tests/transforms/test_generate_heatmapd.py
🪛 Ruff (0.13.1)
monai/transforms/post/array.py
782-782: Avoid specifying long messages outside the exception class
(TRY003)
786-786: Avoid specifying long messages outside the exception class
(TRY003)
789-789: Avoid specifying long messages outside the exception class
(TRY003)
803-805: Avoid specifying long messages outside the exception class
(TRY003)
809-809: Avoid specifying long messages outside the exception class
(TRY003)
848-848: Avoid specifying long messages outside the exception class
(TRY003)
854-854: Avoid specifying long messages outside the exception class
(TRY003)
862-862: Avoid specifying long messages outside the exception class
(TRY003)
tests/transforms/test_generate_heatmapd.py
56-56: Consider (1, *shape)
instead of concatenation
Replace with (1, *shape)
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: build-docs
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.6.0)
🔇 Additional comments (1)
monai/transforms/post/array.py (1)
42-49
: Good use of dtype mapping for backend parity. Using get_equivalent_dtype to consistently map numpy/torch dtypes is the right call here.
Signed-off-by: sewon.jeon <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (12)
tests/transforms/test_generate_heatmapd.py (4)
121-130
: Add dtype assertions for static-shape cases. You pass expected_dtype but only assert equality with numpy’s dtype for static cases; consider also asserting no NaN/Inf and max == 1.0 for parity with other tests.
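A self-contained sketch of the extra assertions being suggested (the point values and explicit `normalize=True` are illustrative assumptions):

```python
import unittest

import numpy as np
import torch

from monai.transforms import GenerateHeatmap


class TestStaticShapeDtype(unittest.TestCase):
    def test_dtype_and_normalized_peak(self):
        points = np.array([[10.0, 10.0]], dtype=np.float32)
        heatmap = GenerateHeatmap(sigma=2.0, spatial_shape=(32, 32), normalize=True, dtype=np.float32)(points)
        values = torch.as_tensor(heatmap)
        self.assertEqual(values.dtype, torch.float32)                # dtype respected
        self.assertTrue(torch.isfinite(values).all())                # no NaN/Inf
        self.assertAlmostEqual(float(values.max()), 1.0, places=5)   # normalized peak


if __name__ == "__main__":
    unittest.main()
```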
131-136
: Verify error message content. Also assert the exception string contains a helpful hint (e.g., “Provide spatial_shape or ref_image_keys”) to lock behavior.
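For example, a standalone test along these lines would lock the hint into the message (the expected fragment is taken from the suggestion above):

```python
import unittest

import numpy as np

from monai.transforms import GenerateHeatmapd


class TestMissingShapeMessage(unittest.TestCase):
    def test_error_message_hint(self):
        points = np.zeros((3, 2), dtype=np.float32)
        # no spatial_shape and no ref_image_keys: the transform cannot resolve a shape
        with self.assertRaisesRegex(ValueError, "Provide spatial_shape or ref_image_keys"):
            GenerateHeatmapd(keys="points", sigma=2.0)({"points": points})


if __name__ == "__main__":
    unittest.main()
```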
186-205
: Add negative/mismatch cases for key lengths. Please add tests for the following (a sketch appears after this list):
- Mismatched heatmap_keys length → ValueError.
- Mismatched ref_image_keys length → ValueError.
- Per-key spatial_shape length mismatch → ValueError.
These hit the explicit validation branches.
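A sketch of what those three cases could look like, assuming the length checks run at construction time (if they are deferred to the call, the `with` blocks would wrap the call instead):

```python
import unittest

from monai.transforms import GenerateHeatmapd


class TestGenerateHeatmapdValidation(unittest.TestCase):
    def test_mismatched_heatmap_keys(self):
        with self.assertRaises(ValueError):
            GenerateHeatmapd(keys=["pts_a", "pts_b"], heatmap_keys=["h1", "h2", "h3"], sigma=2.0)

    def test_mismatched_ref_image_keys(self):
        with self.assertRaises(ValueError):
            GenerateHeatmapd(keys=["pts_a", "pts_b"], ref_image_keys=["img1", "img2", "img3"], sigma=2.0)

    def test_mismatched_per_key_spatial_shape(self):
        with self.assertRaises(ValueError):
            GenerateHeatmapd(keys=["pts_a", "pts_b"], spatial_shape=[(32, 32)] * 3, sigma=2.0)


if __name__ == "__main__":
    unittest.main()
```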
206-231
: Align test with current behavior: affine comes from reference. The comment suggests the heatmap “may inherit from points”, but the implementation sets heatmap.affine from the reference. Prefer asserting that explicitly.
Apply this diff:
-    # Note: Currently the heatmap may inherit affine from points MetaTensor
-    # This test documents the current behavior
-    # Ideally, the heatmap should use the reference image's affine
+    # Heatmap should inherit affine from the reference image
+    assert_allclose(heatmap.affine, image.affine, type_test=False)
monai/transforms/post/dictionary.py (3)
516-527
: Docstring lacks parameter/return details. Briefly document keys, heatmap_keys/ref_image_keys broadcasting rules, dtype/device behavior, and raised exceptions to match the rest of this module.
559-581
: Spatial metadata should use the resolved shape (not reference.ndim). Relying on reference.ndim-1 to infer spatial dims can be brittle with unusual channel layouts/batch dims. You already have the resolved shape; pass it through.
Apply this diff (two locations):
@@ - if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): - heatmap.affine = reference.affine - self._update_spatial_metadata(heatmap, reference) + if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): + heatmap.affine = reference.affine + self._update_spatial_metadata(heatmap, shape)And update the helper:
@@ - def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None: - """Update spatial metadata of heatmap based on its dimensions.""" - # trailing dims after channel are spatial regardless of batch presence - spatial_shape = heatmap.shape[-(reference.ndim - 1) :] - heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape) + def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int, ...]) -> None: + """Set spatial_shape explicitly from resolved shape.""" + heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)
624-636
: Error path for invalid points: message clarity. Consider appending the actual ndim found to the error string to help users debug malformed inputs.
monai/transforms/post/array.py (5)
756-775
: Clarify coordinate order and output shape semantics. Docstring is good; suggest explicitly stating that target_shape is (Z, Y, X) for 3D and (Y, X) for 2D, and output layout is channel-first per-point heatmaps (N, …) or (B, N, …).
825-846
: Avoid device sync on normalization. peak.item() forces host sync on GPU. Keep math on-device to reduce stalls.
Apply this diff:
- if self.normalize: - peak = updated.max() - if peak.item() > 0: - heatmap[b_idx, idx] /= peak + if self.normalize: + peak = updated.amax() + denom = torch.where(peak > 0, peak, torch.ones_like(peak)) + heatmap[b_idx, idx] = heatmap[b_idx, idx] / denom
854-867
: Spatial shape resolution UX. Broadcasting a single int to D dims is handy; consider adding that behavior to the class docstring for discoverability.
896-920
: Gaussian evaluation is correct; minor micro-opts possible. Current implementation is clear; if needed, precompute reshape shapes outside the loop to shave overhead in hot paths.
787-797
: Ruff TRY003: long messages in exceptions. Linters flag long exception strings; optional to shorten or centralize messages, but not a blocker.
Also applies to: 810-817, 857-865, 873-873
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base
setting
📒 Files selected for processing (3)
monai/transforms/post/array.py
(4 hunks)monai/transforms/post/dictionary.py
(4 hunks)tests/transforms/test_generate_heatmapd.py
(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/transforms/post/dictionary.py
monai/transforms/post/array.py
tests/transforms/test_generate_heatmapd.py
🪛 Ruff (0.13.1)
monai/transforms/post/array.py
789-789: Avoid specifying long messages outside the exception class
(TRY003)
793-793: Avoid specifying long messages outside the exception class
(TRY003)
796-796: Avoid specifying long messages outside the exception class
(TRY003)
810-812: Avoid specifying long messages outside the exception class
(TRY003)
816-816: Avoid specifying long messages outside the exception class
(TRY003)
857-857: Avoid specifying long messages outside the exception class
(TRY003)
863-865: Avoid specifying long messages outside the exception class
(TRY003)
873-873: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
🔇 Additional comments (6)
tests/transforms/test_generate_heatmapd.py (1)
99-120
: Solid coverage for reference-driven behavior. The assertions on MetaTensor type, shape, affine propagation, and normalization are on point.
monai/transforms/post/dictionary.py (2)
592-601
: Key broadcasting logic: good and consistent. Tuplifying and broadcasting single entries to len(keys) matches established patterns in this file.
602-623
: Shape normalization handles all call forms. Accepts None, single shape, or per-key shapes robustly; clear error on length mismatch.
monai/transforms/post/array.py (3)
787-804
: Dtype mapping via get_equivalent_dtype: LGTM. Robust handling for numpy/torch dtypes; positive checks for sigma/truncated are correct.
803-813
: Input validation: good guardrails. Batched and non-batched shapes validated with clear messaging; 2D/3D restriction enforced.
850-852
: Type/device conversion: correct and consistent. Returning to the original container type with appropriate dtype is consistent with MONAI conventions.
force-pushed from bce6263 to 40c492d
Actionable comments posted: 0
🧹 Nitpick comments (2)
monai/transforms/post/array.py (1)
754-922
: Heatmap generation looks correct; tighten normalization and minor polish. The algorithm and I/O handling are solid. One tweak: normalize using the channel’s amax (post-write) instead of the window’s amax to be robust and clearer.
Apply this diff:
@@ - if self.normalize: - peak = updated.amax() - denom = torch.where(peak > 0, peak, torch.ones_like(peak)) - heatmap[b_idx, idx] = heatmap[b_idx, idx] / denom + if self.normalize: + peak = heatmap[b_idx, idx].amax() + denom = torch.where(peak > 0, peak, torch.ones_like(peak)) + heatmap[b_idx, idx].div_(denom)Optional:
- Factor exception messages into class-level constants (like in GenerateHeatmapd) to satisfy Ruff TRY003. As per coding guidelines and static analysis hints.
monai/transforms/post/dictionary.py (1)
516-679
: Fix sigma doc and validate static_shape rank
- Doc: change “sequence matching number of points” → “sequence matching the number of spatial dimensions.”
- In
_determine_shape
, convert points first, computespatial_dims
, then ifstatic_shape
is set raiseValueError
whenlen(static_shape) != spatial_dims
.@@ class GenerateHeatmapd(MapTransform): - sigma: standard deviation for the Gaussian kernel. Can be a single value or sequence matching number of points. + sigma: standard deviation for the Gaussian kernel. Can be a single value or a sequence matching the number of spatial dimensions.@@ def _determine_shape( - if static_shape is not None: - return static_shape + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + if points_t.ndim not in (2, 3): + raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") + spatial_dims = int(points_t.shape[-1]) + if static_shape is not None: + if len(static_shape) != spatial_dims: + raise ValueError(f"Provided static spatial_shape has {len(static_shape)} dims; expected {spatial_dims}.") + return static_shape if ref_key is not None and ref_key in data: return self._shape_from_reference(data[ref_key], spatial_dims) raise ValueError(self._ERR_NO_SHAPE)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base
setting
📒 Files selected for processing (3)
monai/transforms/post/array.py
(4 hunks)monai/transforms/post/dictionary.py
(4 hunks)tests/transforms/test_generate_heatmapd.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/transforms/test_generate_heatmapd.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/transforms/post/array.py
monai/transforms/post/dictionary.py
🪛 Ruff (0.13.1)
monai/transforms/post/array.py
791-791: Avoid specifying long messages outside the exception class
(TRY003)
795-795: Avoid specifying long messages outside the exception class
(TRY003)
798-798: Avoid specifying long messages outside the exception class
(TRY003)
812-814: Avoid specifying long messages outside the exception class
(TRY003)
818-818: Avoid specifying long messages outside the exception class
(TRY003)
859-859: Avoid specifying long messages outside the exception class
(TRY003)
865-867: Avoid specifying long messages outside the exception class
(TRY003)
875-875: Avoid specifying long messages outside the exception class
(TRY003)
monai/transforms/post/dictionary.py
654-654: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: packaging
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
🔇 Additional comments (7)
monai/transforms/post/array.py (3)
19-19
: Good: backend typing clarified. Using ClassVar for backend is consistent with MONAI style.
42-49
: Good: dtype and tensor utils imported where needed. Imports align with new dtype handling in GenerateHeatmap.
65-65
: Public export added. Including GenerateHeatmap in `__all__` looks correct.
monai/transforms/post/dictionary.py (4)
38-38
: Import of array-level transform is correct. GenerateHeatmap is referenced directly and used as the engine.
52-52
: Type/device alignment hook imported. convert_to_dst_type import is appropriate for ref matching.
100-103
: Exports added. GenerateHeatmapd and aliases exported as expected.
677-678
: Aliases defined. Consistent with MONAI aliasing pattern.
Enhances the GenerateHeatmap transform with better normalization, spatial metadata handling, and comprehensive documentation. The changes ensure correct heatmap normalization, and improve handling of spatial metadata inheritance from reference images. Also improves input validation and fixes shape inconsistencies. Adds new test cases to cover edge cases and improve code reliability. Signed-off-by: sewon.jeon <[email protected]>
Fixes an issue where heatmap normalization was using the entire heatmap instead of the local region. Adds a check to ensure that the provided static shape matches the number of spatial dimensions. Signed-off-by: sewon.jeon <[email protected]>
force-pushed from 40c492d to 1b5888b
Actionable comments posted: 3
🧹 Nitpick comments (8)
monai/transforms/post/array.py (3)
778-779
: Annotate mutable class attributebackend
with ClassVar to satisfy Ruff (RUF012). Minor typing fix; consistent with linter guidance.
Apply this diff:
- backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + backend: ClassVar[list[TransformBackends]] = [TransformBackends.NUMPY, TransformBackends.TORCH]And add the import (outside this hunk):
from typing import ClassVar
830-847
: Avoid unnecessary CPU sync and use torch finiteness checks.
`.tolist()`
- center_vals = center.tolist() - if not np.all(np.isfinite(center_vals)): - continue + if not torch.isfinite(center).all(): + continue + center_vals = [float(c) for c in center.detach().cpu().tolist()]
753-776
: Docstring polish and TRY003 linter noise.Messages are clear but long raises trigger TRY003. Optional: shorten messages or centralize them as class constants.
monai/transforms/post/dictionary.py (5)
553-554
: Annotatebackend
with ClassVar to satisfy RUF012.Minor typing improvement.
Apply this diff:
- backend = GenerateHeatmap.backend + backend: ClassVar = GenerateHeatmap.backendAnd add (outside this hunk) if not present in this module:
from typing import ClassVar
583-605
: MetaTensor propagation: ensure MetaTensor is kept after conversion.
convert_to_dst_type(..., dtype=heatmap.dtype, device=reference.device)
may return a plain Tensor if dtype/device match but meta wasn’t tracked. Safer to explicitly wrap to MetaTensor when reference is MetaTensor. Apply this diff:
- if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)): + if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)): # Convert to match reference type and device while preserving heatmap's dtype heatmap, _, _ = convert_to_dst_type( heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None) ) # Copy metadata if reference is MetaTensor - if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): - heatmap.affine = reference.affine - self._update_spatial_metadata(heatmap, shape) + if isinstance(reference, MetaTensor): + if not isinstance(heatmap, MetaTensor): + heatmap = MetaTensor(heatmap, affine=reference.affine, meta=reference.meta.copy()) + else: + heatmap.affine = reference.affine + self._update_spatial_metadata(heatmap, shape)
648-664
: Shorten/centralize error strings (TRY003). Optional: reduce f-string length or reuse class-level constants to silence linter.
516-551
: Docstring clarity: unify axis notation and output shape description. Use consistent Z/Y/X naming across 2D/3D (matches array variant) to avoid confusion with H/W/D wording.
583-605
: Shape source precedence: warn when both static shape and ref are provided but disagree. Currently static shape silently wins. Consider logging a warning if both exist and shapes differ.
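One possible shape for that warning, as a small self-contained sketch (the helper name is hypothetical; it would be called from the shape-resolution path discussed here):

```python
import warnings
from collections.abc import Sequence


def warn_if_shapes_disagree(static_shape: Sequence[int], ref_shape: Sequence[int]) -> None:
    """Warn when an explicit spatial_shape disagrees with the reference-derived shape."""
    if tuple(static_shape) != tuple(ref_shape):
        warnings.warn(
            f"spatial_shape {tuple(static_shape)} differs from the reference-derived shape "
            f"{tuple(ref_shape)}; using the explicitly provided spatial_shape.",
            UserWarning,
        )
```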
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base
setting
📒 Files selected for processing (3)
monai/transforms/post/array.py
(3 hunks)monai/transforms/post/dictionary.py
(4 hunks)tests/transforms/test_generate_heatmapd.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/transforms/test_generate_heatmapd.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/transforms/post/dictionary.py
monai/transforms/post/array.py
🪛 Ruff (0.13.1)
monai/transforms/post/dictionary.py
653-653: Avoid specifying long messages outside the exception class
(TRY003)
657-659: Avoid specifying long messages outside the exception class
(TRY003)
monai/transforms/post/array.py
778-778: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
790-790: Avoid specifying long messages outside the exception class
(TRY003)
794-794: Avoid specifying long messages outside the exception class
(TRY003)
797-797: Avoid specifying long messages outside the exception class
(TRY003)
811-813: Avoid specifying long messages outside the exception class
(TRY003)
817-817: Avoid specifying long messages outside the exception class
(TRY003)
858-858: Avoid specifying long messages outside the exception class
(TRY003)
864-866: Avoid specifying long messages outside the exception class
(TRY003)
874-874: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (3)
monai/transforms/post/array.py (1)
41-48
: LGTM on utils import expansion.
get_equivalent_dtype
is appropriate here.monai/transforms/post/dictionary.py (2)
38-39
: Importing array-level transform is correct. Keeps wrapper aligned with core implementation.
52-53
: Using convert_to_dst_type is appropriate for dtype/device alignment. No issues.
Adds a `GenerateHeatmap` transform to generate heatmaps from point data. This transform creates heatmaps from point data, validating that the dtype is a floating-point type. Signed-off-by: sewon.jeon <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (4)
monai/transforms/post/array.py (4)
803-805
: Future-proof the dtype check. Use the dtype’s is_floating_point attribute to avoid hard-coding allowed types.
Apply this diff:
- # Validate that dtype is floating-point for meaningful Gaussian values - if self.torch_dtype not in (torch.float16, torch.float32, torch.float64, torch.bfloat16): - raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}") + # Validate that dtype is floating-point for meaningful Gaussian values + if not getattr(self.torch_dtype, "is_floating_point", False): + raise ValueError("dtype must be a floating-point type (e.g., torch.float32 / np.float32).")
829-851
: Avoid CPU sync and NumPy in the inner loop (GPU-friendly check). Replace tolist()/np.isfinite with torch-native checks and precompute bounds tensor once per call. Reduces host-device sync and improves performance on CUDA.
Apply this diff:
heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device) image_bounds = tuple(int(s) for s in target_shape) + bounds_t = torch.as_tensor(image_bounds, device=device, dtype=points_t.dtype) for b_idx in range(batch_size): for idx, center in enumerate(points_t[b_idx]): - center_vals = center.tolist() - if not np.all(np.isfinite(center_vals)): - continue - if not self._is_inside(center_vals, image_bounds): - continue - window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device) + if not torch.isfinite(center).all(): + continue + if not ((center >= 0).all() and (center < bounds_t).all()): + continue + # _make_window expects Python floats; pass only when needed + window_slices, coord_shifts = self._make_window(center.tolist(), radius, image_bounds, device) if window_slices is None: continue region = heatmap[b_idx, idx][window_slices] gaussian = self._evaluate_gaussian(coord_shifts, sigma) updated = torch.maximum(region, gaussian) # write back region.copy_(updated) if self.normalize: peak = heatmap[b_idx, idx].amax() denom = torch.where(peak > 0, peak, torch.ones_like(peak)) heatmap[b_idx, idx].div_(denom)
807-857
: Missing method docstring for `__call__`. Add a short docstring describing inputs/outputs and shapes to match neighboring transforms.
Apply this diff:
def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor: + """ + Args: + points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D), + ordered as (Y, X) for 2D or (Z, Y, X) for 3D. + spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses + the value provided at construction. + Returns: + Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark. + Raises: + ValueError: if points shape/dimension or spatial_shape is invalid. + """
778-778
: Optional: annotate backend as ClassVar to satisfy lint. If you follow Ruff’s RUF012, annotate backend as ClassVar. This file has other instances without it, so adopt consistently if you enable the rule globally.
Apply this diff within the class:
- backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] # consider: backend: ClassVar[list[TransformBackends]] = [...]And add at the top of the file if standardizing:
from typing import ClassVar
Based on learnings
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base
setting
📒 Files selected for processing (2)
monai/transforms/__init__.py
(2 hunks)monai/transforms/post/array.py
(3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/transforms/post/array.py
monai/transforms/__init__.py
🪛 Ruff (0.13.1)
monai/transforms/post/array.py
778-778: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
790-790: Avoid specifying long messages outside the exception class
(TRY003)
794-794: Avoid specifying long messages outside the exception class
(TRY003)
797-797: Avoid specifying long messages outside the exception class
(TRY003)
804-804: Avoid specifying long messages outside the exception class
(TRY003)
814-816: Avoid specifying long messages outside the exception class
(TRY003)
820-820: Avoid specifying long messages outside the exception class
(TRY003)
861-861: Avoid specifying long messages outside the exception class
(TRY003)
867-869: Avoid specifying long messages outside the exception class
(TRY003)
877-877: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: packaging
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
🔇 Additional comments (4)
monai/transforms/__init__.py (2)
296-296
: Top-level export for GenerateHeatmap looks good. Re-export added correctly; aligns with new public API.
323-326
: Dict-variant exports added correctly. GenerateHeatmapd/GenerateHeatmapD/GenerateHeatmapDict are re-exported as expected.
monai/transforms/post/array.py (2)
41-48
: Import additions are appropriate. get_equivalent_dtype and related utils are correctly imported for the new transform.
64-64
: Public export updated. Adding "GenerateHeatmap" to `__all__` is correct.
Signed-off-by: sewon.jeon <[email protected]>
force-pushed from 0fec3ad to fc28c71
Fixes #3328.
Description
This pull request introduces `GenerateHeatmap` and `GenerateHeatmapd` transforms for creating Gaussian heatmaps from landmark coordinates.
The input points are currently expected in ZYX order, but this can be changed to support XYZ if preferred.
The transforms support both batched (B, N, D) and non-batched (N, D) inputs.
Example notebooks are included for demonstration and will be removed before the PR is merged.
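For reviewers trying the dictionary variant, a minimal sketch (key names and sizes are made up for illustration; the default heatmap key follows the `<key>_heatmap` convention described in the review comments):

```python
import numpy as np
import torch

from monai.data import MetaTensor
from monai.transforms import GenerateHeatmapd

data = {
    "image": MetaTensor(torch.zeros(1, 64, 64)),                          # (C, H, W) reference
    "points": np.array([[20.0, 20.0], [40.0, 50.0]], dtype=np.float32),   # (N, D) landmarks in (Y, X)
}
# spatial shape is inferred from the reference image; output written to "points_heatmap"
transform = GenerateHeatmapd(keys="points", ref_image_keys="image", sigma=3.0)
out = transform(data)
print(out["points_heatmap"].shape)  # expected: (2, 64, 64)
```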
Types of changes
- `./runtests.sh -f -u --net --coverage`
- `./runtests.sh --quick --unittests --disttests`
- `make html` command in the `docs/` folder.