Skip to content

Commit 40c492d

Browse files
committed
Improves GenerateHeatmap transform and documentation
Enhances the GenerateHeatmap transform with better normalization, spatial metadata handling, and comprehensive documentation. The changes ensure correct heatmap normalization, and improve handling of spatial metadata inheritance from reference images. Also improves input validation and fixes shape inconsistencies. Adds new test cases to cover edge cases and improve code reliability. Signed-off-by: sewon.jeon <[email protected]>
1 parent 1bf0850 commit 40c492d

File tree

3 files changed

+73
-16
lines changed

3 files changed

+73
-16
lines changed

monai/transforms/post/array.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -757,14 +757,16 @@ class GenerateHeatmap(Transform):
757757
758758
Notes:
759759
- Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
760-
- Output shape:
761-
- Non-batched points (N, D): (N, H, W[, D])
762-
- Batched points (B, N, D): (B, N, H, W[, D])
760+
- Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.
761+
- Output layout uses channel-first convention with one channel per landmark:
762+
- Non-batched points (N, D): (N, Y, X) for 2D or (N, Z, Y, X) for 3D
763+
- Batched points (B, N, D): (B, N, Y, X) for 2D or (B, N, Z, Y, X) for 3D
763764
- Each channel corresponds to one landmark.
764765
765766
Args:
766767
sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
767768
spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
769+
A single int value will be broadcast to all spatial dimensions.
768770
truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
769771
normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
770772
dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).
@@ -840,9 +842,9 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
840842
# write back
841843
region.copy_(updated)
842844
if self.normalize:
843-
peak = updated.max()
844-
if peak.item() > 0:
845-
heatmap[b_idx, idx] /= peak
845+
peak = updated.amax()
846+
denom = torch.where(peak > 0, peak, torch.ones_like(peak))
847+
heatmap[b_idx, idx] = heatmap[b_idx, idx] / denom
846848

847849
if not is_batched:
848850
heatmap = heatmap.squeeze(0)

monai/transforms/post/dictionary.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,35 @@ class GenerateHeatmapd(MapTransform):
518518
Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`.
519519
Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image.
520520
521+
Args:
522+
keys: keys of the corresponding items in the dictionary.
523+
sigma: standard deviation for the Gaussian kernel. Can be a single value or sequence matching number of points.
524+
heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key.
525+
ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will
526+
have the same shape, affine, and spatial metadata as the reference images.
527+
spatial_shape: spatial dimensions of output heatmaps. Can be:
528+
- Single shape (tuple): applied to all keys
529+
- List of shapes: one per key (must match keys length)
530+
truncated: truncation distance for Gaussian kernel computation (in sigmas).
531+
normalize: if True, normalize each heatmap's peak value to 1.0.
532+
dtype: output data type for heatmaps. Defaults to np.float32.
533+
allow_missing_keys: if True, don't raise error if some keys are missing in data.
534+
535+
Returns:
536+
Dictionary with original data plus generated heatmaps at specified keys.
537+
538+
Raises:
539+
ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length.
540+
ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys).
541+
ValueError: If input points have invalid shape (must be 2D or 3D).
542+
521543
Notes:
522544
- Default heatmap_keys are generated as "{key}_heatmap" for each input key
523545
- Shape inference precedence: static spatial_shape > ref_image
524546
- Output shapes:
525547
- Non-batched points (N, D): (N, H, W[, D])
526548
- Batched points (B, N, D): (B, N, H, W[, D])
549+
- When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference
527550
"""
528551

529552
backend = GenerateHeatmap.backend
@@ -575,7 +598,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
575598
# Copy metadata if reference is MetaTensor
576599
if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):
577600
heatmap.affine = reference.affine
578-
self._update_spatial_metadata(heatmap, reference)
601+
self._update_spatial_metadata(heatmap, shape)
579602
d[out_key] = heatmap
580603
return d
581604

@@ -628,7 +651,7 @@ def _determine_shape(
628651
return static_shape
629652
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
630653
if points_t.ndim not in (2, 3):
631-
raise ValueError(self._ERR_INVALID_POINTS)
654+
raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.")
632655
spatial_dims = int(points_t.shape[-1])
633656
if ref_key is not None and ref_key in data:
634657
return self._shape_from_reference(data[ref_key], spatial_dims)
@@ -646,10 +669,8 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
646669
return tuple(int(v) for v in reference.shape[-spatial_dims:])
647670
raise ValueError(self._ERR_REF_NO_SHAPE)
648671

649-
def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None:
650-
"""Update spatial metadata of heatmap based on its dimensions."""
651-
# trailing dims after channel are spatial regardless of batch presence
652-
spatial_shape = heatmap.shape[-(reference.ndim - 1) :]
672+
def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int, ...]) -> None:
673+
"""Set spatial_shape explicitly from resolved shape."""
653674
heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)
654675

655676

tests/transforms/test_generate_heatmapd.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,16 @@ def test_dict_static_shape(self, _, points, params, expected_shape, expected_dty
128128
self.assertEqual(heatmap.shape, expected_shape)
129129
self.assertEqual(heatmap.dtype, expected_dtype)
130130

131+
# Verify no NaN or Inf values
132+
self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any())
133+
134+
# Verify max value is 1.0 for normalized heatmaps
135+
np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5)
136+
131137
def test_dict_missing_shape_raises(self):
132138
# Without ref image or explicit spatial_shape, must raise
133139
transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap")
134-
with self.assertRaises(ValueError):
140+
with self.assertRaisesRegex(ValueError, "spatial_shape|ref_image_keys"):
135141
transform({"points": np.zeros((1, 2), dtype=np.float32)})
136142

137143
@parameterized.expand(TEST_CASES_DTYPE)
@@ -203,6 +209,35 @@ def test_dict_multiple_keys(self):
203209
# Verify peaks are at different locations
204210
self.assertNotEqual(np.argmax(result["hm1"]), np.argmax(result["hm2"]))
205211

212+
def test_dict_mismatched_heatmap_keys_length(self):
213+
"""Test ValueError when heatmap_keys length doesn't match keys"""
214+
with self.assertRaises(ValueError):
215+
GenerateHeatmapd(
216+
keys=["pts1", "pts2"],
217+
heatmap_keys=["hm1", "hm2", "hm3"], # Mismatch: 3 heatmap keys for 2 input keys
218+
spatial_shape=(8, 8),
219+
)
220+
221+
def test_dict_mismatched_ref_image_keys_length(self):
222+
"""Test ValueError when ref_image_keys length doesn't match keys"""
223+
with self.assertRaises(ValueError):
224+
GenerateHeatmapd(
225+
keys=["pts1", "pts2"],
226+
heatmap_keys=["hm1", "hm2"],
227+
ref_image_keys=["img1", "img2", "img3"], # Mismatch: 3 ref keys for 2 input keys
228+
spatial_shape=(8, 8),
229+
)
230+
231+
def test_dict_per_key_spatial_shape_mismatch(self):
232+
"""Test ValueError when per-key spatial_shape length doesn't match keys"""
233+
with self.assertRaises(ValueError):
234+
GenerateHeatmapd(
235+
keys=["pts1", "pts2"],
236+
heatmap_keys=["hm1", "hm2"],
237+
spatial_shape=[(8, 8), (8, 8), (8, 8)], # Mismatch: 3 shapes for 2 keys
238+
sigma=1.0,
239+
)
240+
206241
def test_metatensor_points_with_ref(self):
207242
"""Test MetaTensor points with reference image - documents current behavior"""
208243
from monai.data import MetaTensor
@@ -224,9 +259,8 @@ def test_metatensor_points_with_ref(self):
224259
self.assertIsInstance(heatmap, MetaTensor)
225260
self.assertEqual(tuple(heatmap.shape), (2, 8, 8, 8))
226261

227-
# Note: Currently the heatmap may inherit affine from points MetaTensor
228-
# This test documents the current behavior
229-
# Ideally, the heatmap should use the reference image's affine
262+
# Heatmap should inherit affine from the reference image
263+
assert_allclose(heatmap.affine, image.affine, type_test=False)
230264

231265

232266
if __name__ == "__main__":

0 commit comments

Comments
 (0)