Skip to content

Commit 0e907bb

Browse files
committed
fix meta tensor problem
Signed-off-by: sewon.jeon <[email protected]>
1 parent 5bc7993 commit 0e907bb

File tree

2 files changed

+33
-33
lines changed

2 files changed

+33
-33
lines changed

monai/transforms/post/array.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ def __init__(
783783
else:
784784
if float(sigma) <= 0:
785785
raise ValueError("sigma must be positive.")
786-
self._sigma = float(sigma)
786+
self._sigma = (float(sigma),)
787787
if truncated <= 0:
788788
raise ValueError("truncated must be positive.")
789789
self.truncated = float(truncated)
@@ -826,7 +826,7 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
826826
window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device)
827827
if window_slices is None:
828828
continue
829-
region = heatmap[(b_idx, idx, *window_slices)]
829+
region = heatmap[b_idx, idx][window_slices]
830830
gaussian = self._evaluate_gaussian(coord_shifts, sigma)
831831
torch.maximum(region, gaussian, out=region)
832832
if self.normalize:
@@ -854,13 +854,11 @@ def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims:
854854
return tuple(int(s) for s in shape_tuple)
855855

856856
def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
857-
if isinstance(self._sigma, tuple):
858-
if len(self._sigma) == spatial_dims:
859-
return self._sigma
860-
if len(self._sigma) == 1:
861-
return self._sigma * spatial_dims
862-
raise ValueError("sigma sequence length must equal the number of spatial dimensions.")
863-
return (self._sigma,) * spatial_dims
857+
if len(self._sigma) == spatial_dims:
858+
return self._sigma
859+
if len(self._sigma) == 1:
860+
return self._sigma * spatial_dims
861+
raise ValueError("sigma sequence length must equal the number of spatial dimensions.")
864862

865863
@staticmethod
866864
def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:

monai/transforms/post/dictionary.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -548,9 +548,19 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
548548
):
549549
points = d[key]
550550
shape = self._determine_shape(points, static_shape, d, ref_key)
551+
# The GenerateHeatmap transform will handle type conversion based on input points
551552
heatmap = self.generator(points, spatial_shape=shape)
553+
# If there's a reference image and we need to match its type/device
552554
reference = d.get(ref_key) if ref_key is not None and ref_key in d else None
553-
d[out_key] = self._prepare_output(heatmap, reference)
555+
if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)):
556+
# Convert to match reference type and device while preserving heatmap's dtype
557+
heatmap, _, _ = convert_to_dst_type(
558+
heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None)
559+
)
560+
# Copy metadata if reference is MetaTensor
561+
if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):
562+
self._update_spatial_metadata(heatmap, reference)
563+
d[out_key] = heatmap
554564
return d
555565

556566
def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]:
@@ -622,29 +632,21 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
622632
return tuple(int(v) for v in reference.shape[-spatial_dims:])
623633
raise ValueError("Reference data must define a shape attribute.")
624634

625-
def _prepare_output(self, heatmap: NdarrayOrTensor, reference: Any) -> Any:
626-
if isinstance(reference, MetaTensor):
627-
# Use heatmap's dtype (from generator), not reference's dtype
628-
converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=heatmap.dtype, device=reference.device)
629-
# For batched data shape is (B, C, *spatial), for non-batched it's (C, *spatial)
630-
if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D)
631-
converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
632-
elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
633-
# Need to check if this is batched 2D or non-batched 3D
634-
if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])):
635-
# Non-batched 3D
636-
converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
637-
else:
638-
# Batched 2D
639-
converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
640-
else: # 2D non-batched: (C, H, W)
641-
converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
642-
return converted
643-
if isinstance(reference, torch.Tensor):
644-
# Use heatmap's dtype (from generator), not reference's dtype
645-
converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=heatmap.dtype, device=reference.device)
646-
return converted
647-
return heatmap
635+
def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None:
636+
"""Update spatial metadata of heatmap based on its dimensions."""
637+
# Update spatial_shape metadata based on heatmap dimensions
638+
if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D)
639+
heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
640+
elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
641+
# Need to check if this is batched 2D or non-batched 3D
642+
if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])):
643+
# Non-batched 3D
644+
heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
645+
else:
646+
# Batched 2D
647+
heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
648+
else: # 2D non-batched: (C, H, W)
649+
heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
648650

649651

650652
GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd

0 commit comments

Comments
 (0)