Skip to content

Commit 1b5888b

Browse files
committed
Fixes heatmap normalization and shape checking
Fixes an issue where heatmap normalization was using the entire heatmap instead of the local region. Adds a check to ensure that the provided static shape matches the number of spatial dimensions. Signed-off-by: sewon.jeon <[email protected]>
1 parent 2c7b4d0 commit 1b5888b

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

monai/transforms/post/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -841,9 +841,9 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
841841
# write back
842842
region.copy_(updated)
843843
if self.normalize:
844-
peak = updated.amax()
844+
peak = heatmap[b_idx, idx].amax()
845845
denom = torch.where(peak > 0, peak, torch.ones_like(peak))
846-
heatmap[b_idx, idx] = heatmap[b_idx, idx] / denom
846+
heatmap[b_idx, idx].div_(denom)
847847

848848
if not is_batched:
849849
heatmap = heatmap.squeeze(0)

monai/transforms/post/dictionary.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,8 @@ class GenerateHeatmapd(MapTransform):
520520
521521
Args:
522522
keys: keys of the corresponding items in the dictionary.
523-
sigma: standard deviation for the Gaussian kernel. Can be a single value or sequence matching number of points.
523+
sigma: standard deviation for the Gaussian kernel. Can be a single value or a sequence matching the number
524+
of spatial dimensions.
524525
heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key.
525526
ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will
526527
have the same shape, affine, and spatial metadata as the reference images.
@@ -647,12 +648,16 @@ def _prepare_shapes(
647648
def _determine_shape(
648649
self, points: Any, static_shape: tuple[int, ...] | None, data: Mapping[Hashable, Any], ref_key: Hashable | None
649650
) -> tuple[int, ...]:
650-
if static_shape is not None:
651-
return static_shape
652651
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
653652
if points_t.ndim not in (2, 3):
654653
raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.")
655654
spatial_dims = int(points_t.shape[-1])
655+
if static_shape is not None:
656+
if len(static_shape) != spatial_dims:
657+
raise ValueError(
658+
f"Provided static spatial_shape has {len(static_shape)} dims; expected {spatial_dims}."
659+
)
660+
return static_shape
656661
if ref_key is not None and ref_key in data:
657662
return self._shape_from_reference(data[ref_key], spatial_dims)
658663
raise ValueError(self._ERR_NO_SHAPE)

0 commit comments

Comments
 (0)