@@ -528,7 +528,7 @@ def __init__(
528
528
heatmap_keys : KeysCollection | None = None ,
529
529
ref_image_keys : KeysCollection | None = None ,
530
530
spatial_shape : Sequence [int ] | Sequence [Sequence [int ]] | None = None ,
531
- truncate : float = 3 .0 ,
531
+ truncated : float = 4 .0 ,
532
532
normalize : bool = True ,
533
533
dtype : np .dtype | type = np .float32 ,
534
534
allow_missing_keys : bool = False ,
@@ -540,7 +540,7 @@ def __init__(
540
540
self .generator = GenerateHeatmap (
541
541
sigma = sigma ,
542
542
spatial_shape = None ,
543
- truncate = truncate ,
543
+ truncated = truncated ,
544
544
normalize = normalize ,
545
545
dtype = dtype ,
546
546
)
@@ -632,11 +632,25 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
632
632
633
633
def _prepare_output (self , heatmap : NdarrayOrTensor , reference : Any ) -> Any :
634
634
if isinstance (reference , MetaTensor ):
635
- converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = reference .dtype , device = reference .device )
636
- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
635
+ # Use heatmap's dtype (from generator), not reference's dtype
636
+ converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = heatmap .dtype , device = reference .device )
637
+ # For batched data shape is (B, C, *spatial), for non-batched it's (C, *spatial)
638
+ if heatmap .ndim == 5 : # 3D batched: (B, C, H, W, D)
639
+ converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
640
+ elif heatmap .ndim == 4 : # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
641
+ # Need to check if this is batched 2D or non-batched 3D
642
+ if len (heatmap .shape [1 :]) == len (reference .meta .get ("spatial_shape" , [])):
643
+ # Non-batched 3D
644
+ converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
645
+ else :
646
+ # Batched 2D
647
+ converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
648
+ else : # 2D non-batched: (C, H, W)
649
+ converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
637
650
return converted
638
651
if isinstance (reference , torch .Tensor ):
639
- converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = reference .dtype , device = reference .device )
652
+ # Use heatmap's dtype (from generator), not reference's dtype
653
+ converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = heatmap .dtype , device = reference .device )
640
654
return converted
641
655
return heatmap
642
656
0 commit comments