@@ -548,9 +548,19 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
548
548
):
549
549
points = d [key ]
550
550
shape = self ._determine_shape (points , static_shape , d , ref_key )
551
+ # The GenerateHeatmap transform will handle type conversion based on input points
551
552
heatmap = self .generator (points , spatial_shape = shape )
553
+ # If there's a reference image and we need to match its type/device
552
554
reference = d .get (ref_key ) if ref_key is not None and ref_key in d else None
553
- d [out_key ] = self ._prepare_output (heatmap , reference )
555
+ if reference is not None and isinstance (reference , (torch .Tensor , np .ndarray )):
556
+ # Convert to match reference type and device while preserving heatmap's dtype
557
+ heatmap , _ , _ = convert_to_dst_type (
558
+ heatmap , reference , dtype = heatmap .dtype , device = getattr (reference , "device" , None )
559
+ )
560
+ # Copy metadata if reference is MetaTensor
561
+ if isinstance (reference , MetaTensor ) and isinstance (heatmap , MetaTensor ):
562
+ self ._update_spatial_metadata (heatmap , reference )
563
+ d [out_key ] = heatmap
554
564
return d
555
565
556
566
def _prepare_heatmap_keys (self , heatmap_keys : KeysCollection | None ) -> tuple [Hashable , ...]:
@@ -622,29 +632,21 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
622
632
return tuple (int (v ) for v in reference .shape [- spatial_dims :])
623
633
raise ValueError ("Reference data must define a shape attribute." )
624
634
625
- def _prepare_output (self , heatmap : NdarrayOrTensor , reference : Any ) -> Any :
626
- if isinstance (reference , MetaTensor ):
627
- # Use heatmap's dtype (from generator), not reference's dtype
628
- converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = heatmap .dtype , device = reference .device )
629
- # For batched data shape is (B, C, *spatial), for non-batched it's (C, *spatial)
630
- if heatmap .ndim == 5 : # 3D batched: (B, C, H, W, D)
631
- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
632
- elif heatmap .ndim == 4 : # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
633
- # Need to check if this is batched 2D or non-batched 3D
634
- if len (heatmap .shape [1 :]) == len (reference .meta .get ("spatial_shape" , [])):
635
- # Non-batched 3D
636
- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
637
- else :
638
- # Batched 2D
639
- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
640
- else : # 2D non-batched: (C, H, W)
641
- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
642
- return converted
643
- if isinstance (reference , torch .Tensor ):
644
- # Use heatmap's dtype (from generator), not reference's dtype
645
- converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = heatmap .dtype , device = reference .device )
646
- return converted
647
- return heatmap
635
+ def _update_spatial_metadata (self , heatmap : MetaTensor , reference : MetaTensor ) -> None :
636
+ """Update spatial metadata of heatmap based on its dimensions."""
637
+ # Update spatial_shape metadata based on heatmap dimensions
638
+ if heatmap .ndim == 5 : # 3D batched: (B, C, H, W, D)
639
+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
640
+ elif heatmap .ndim == 4 : # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
641
+ # Need to check if this is batched 2D or non-batched 3D
642
+ if len (heatmap .shape [1 :]) == len (reference .meta .get ("spatial_shape" , [])):
643
+ # Non-batched 3D
644
+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
645
+ else :
646
+ # Batched 2D
647
+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
648
+ else : # 2D non-batched: (C, H, W)
649
+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
648
650
649
651
650
652
GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd
0 commit comments