@@ -518,12 +518,35 @@ class GenerateHeatmapd(MapTransform):
518
518
Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`.
519
519
Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image.
520
520
521
+ Args:
522
+ keys: keys of the corresponding items in the dictionary.
523
+ sigma: standard deviation for the Gaussian kernel. Can be a single value or sequence matching number of points.
524
+ heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key.
525
+ ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will
526
+ have the same shape, affine, and spatial metadata as the reference images.
527
+ spatial_shape: spatial dimensions of output heatmaps. Can be:
528
+ - Single shape (tuple): applied to all keys
529
+ - List of shapes: one per key (must match keys length)
530
+ truncated: truncation distance for Gaussian kernel computation (in sigmas).
531
+ normalize: if True, normalize each heatmap's peak value to 1.0.
532
+ dtype: output data type for heatmaps. Defaults to np.float32.
533
+ allow_missing_keys: if True, don't raise error if some keys are missing in data.
534
+
535
+ Returns:
536
+ Dictionary with original data plus generated heatmaps at specified keys.
537
+
538
+ Raises:
539
+ ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length.
540
+ ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys).
541
+ ValueError: If input points have invalid shape (must be 2D or 3D).
542
+
521
543
Notes:
522
544
- Default heatmap_keys are generated as "{key}_heatmap" for each input key
523
545
- Shape inference precedence: static spatial_shape > ref_image
524
546
- Output shapes:
525
547
- Non-batched points (N, D): (N, H, W[, D])
526
548
- Batched points (B, N, D): (B, N, H, W[, D])
549
+ - When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference
527
550
"""
528
551
529
552
backend = GenerateHeatmap .backend
@@ -575,7 +598,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
575
598
# Copy metadata if reference is MetaTensor
576
599
if isinstance (reference , MetaTensor ) and isinstance (heatmap , MetaTensor ):
577
600
heatmap .affine = reference .affine
578
- self ._update_spatial_metadata (heatmap , reference )
601
+ self ._update_spatial_metadata (heatmap , shape )
579
602
d [out_key ] = heatmap
580
603
return d
581
604
@@ -628,7 +651,7 @@ def _determine_shape(
628
651
return static_shape
629
652
points_t = convert_to_tensor (points , dtype = torch .float32 , track_meta = False )
630
653
if points_t .ndim not in (2 , 3 ):
631
- raise ValueError (self ._ERR_INVALID_POINTS )
654
+ raise ValueError (f" { self ._ERR_INVALID_POINTS } Got { points_t . ndim } D tensor." )
632
655
spatial_dims = int (points_t .shape [- 1 ])
633
656
if ref_key is not None and ref_key in data :
634
657
return self ._shape_from_reference (data [ref_key ], spatial_dims )
@@ -646,10 +669,8 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
646
669
return tuple (int (v ) for v in reference .shape [- spatial_dims :])
647
670
raise ValueError (self ._ERR_REF_NO_SHAPE )
648
671
649
- def _update_spatial_metadata (self , heatmap : MetaTensor , reference : MetaTensor ) -> None :
650
- """Update spatial metadata of heatmap based on its dimensions."""
651
- # trailing dims after channel are spatial regardless of batch presence
652
- spatial_shape = heatmap .shape [- (reference .ndim - 1 ) :]
672
+ def _update_spatial_metadata (self , heatmap : MetaTensor , spatial_shape : tuple [int , ...]) -> None :
673
+ """Set spatial_shape explicitly from resolved shape."""
653
674
heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in spatial_shape )
654
675
655
676
0 commit comments