Commit 20bbf6b

rename parameter
1 parent 25ceb7f commit 20bbf6b

3 files changed: +44 −11 lines changed


monai/transforms/post/array.py

Lines changed: 6 additions & 6 deletions
@@ -757,7 +757,7 @@ class GenerateHeatmap(Transform):
     Args:
         sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
         spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
-        truncate: extent, in multiples of ``sigma``, used to crop the gaussian support window.
+        truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
         normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
         dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).
@@ -772,7 +772,7 @@ def __init__(
         self,
         sigma: Sequence[float] | float = 5.0,
         spatial_shape: Sequence[int] | None = None,
-        truncate: float = 3.0,
+        truncated: float = 4.0,
         normalize: bool = True,
         dtype: np.dtype | torch.dtype | type = np.float32,
     ) -> None:
@@ -784,9 +784,9 @@ def __init__(
         if float(sigma) <= 0:
             raise ValueError("sigma must be positive.")
         self._sigma = float(sigma)
-        if truncate <= 0:
-            raise ValueError("truncate must be positive.")
-        self.truncate = float(truncate)
+        if truncated <= 0:
+            raise ValueError("truncated must be positive.")
+        self.truncated = float(truncated)
         self.normalize = normalize
         self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
         self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
@@ -816,7 +816,7 @@ def __call__(

         target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
         sigma = self._resolve_sigma(spatial_dims)
-        radius = tuple(int(np.ceil(self.truncate * s)) for s in sigma)
+        radius = tuple(int(np.ceil(self.truncated * s)) for s in sigma)

         heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device)
         image_bounds = tuple(int(s) for s in target_shape)
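
For context on the new default: the crop radius in __call__ is ceil(truncated * sigma) per spatial axis, so truncated=4.0 keeps roughly four standard deviations of the gaussian on each side of a point. A minimal usage sketch of the renamed keyword, mirroring the new test further down and assuming the transform is importable from monai.transforms.post.array as the diff path suggests (coordinates and shapes are illustrative only):

import numpy as np
from monai.transforms.post.array import GenerateHeatmap

pt = np.array([[8.0, 8.0]], dtype=np.float32)  # one 2D point, as in the test below
tf = GenerateHeatmap(sigma=2.0, spatial_shape=(32, 32), truncated=4.0)
heatmap = tf(pt)[0]  # gaussian support cropped at ceil(truncated * sigma) = ceil(4.0 * 2.0) = 8 pixels per axis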

monai/transforms/post/dictionary.py

Lines changed: 19 additions & 5 deletions
@@ -528,7 +528,7 @@ def __init__(
         heatmap_keys: KeysCollection | None = None,
         ref_image_keys: KeysCollection | None = None,
         spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None,
-        truncate: float = 3.0,
+        truncated: float = 4.0,
         normalize: bool = True,
         dtype: np.dtype | type = np.float32,
         allow_missing_keys: bool = False,
@@ -540,7 +540,7 @@ def __init__(
         self.generator = GenerateHeatmap(
             sigma=sigma,
             spatial_shape=None,
-            truncate=truncate,
+            truncated=truncated,
             normalize=normalize,
             dtype=dtype,
         )
@@ -632,11 +632,25 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,

     def _prepare_output(self, heatmap: NdarrayOrTensor, reference: Any) -> Any:
         if isinstance(reference, MetaTensor):
-            converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=reference.dtype, device=reference.device)
-            converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
+            # Use heatmap's dtype (from generator), not reference's dtype
+            converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=heatmap.dtype, device=reference.device)
+            # For batched data shape is (B, C, *spatial), for non-batched it's (C, *spatial)
+            if heatmap.ndim == 5:  # 3D batched: (B, C, H, W, D)
+                converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
+            elif heatmap.ndim == 4:  # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
+                # Need to check if this is batched 2D or non-batched 3D
+                if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])):
+                    # Non-batched 3D
+                    converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
+                else:
+                    # Batched 2D
+                    converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
+            else:  # 2D non-batched: (C, H, W)
+                converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
             return converted
         if isinstance(reference, torch.Tensor):
-            converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=reference.dtype, device=reference.device)
+            # Use heatmap's dtype (from generator), not reference's dtype
+            converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=heatmap.dtype, device=reference.device)
             return converted
         return heatmap
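
The MetaTensor branch above has to recover the spatial shape from a heatmap whose rank alone is ambiguous: a 4D tensor can be either a batched 2D result (B, C, H, W) or a non-batched 3D result (C, H, W, D). A standalone sketch of that branching, using a hypothetical helper name and illustrative shapes rather than the committed method, which disambiguates via the length of the reference's spatial_shape metadata:

import torch

def spatial_shape_from_heatmap(heatmap: torch.Tensor, ref_spatial_ndim: int) -> tuple[int, ...]:
    # Mirror of the branching added above: strip the leading (B, C) or (C,) dims.
    if heatmap.ndim == 5:  # 3D batched: (B, C, H, W, D)
        return tuple(heatmap.shape[2:])
    if heatmap.ndim == 4:  # ambiguous: 2D batched or 3D non-batched
        if heatmap.ndim - 1 == ref_spatial_ndim:  # (C, H, W, D) matches a 3D reference
            return tuple(heatmap.shape[1:])
        return tuple(heatmap.shape[2:])  # otherwise treat as (B, C, H, W)
    return tuple(heatmap.shape[1:])  # 2D non-batched: (C, H, W)

# A (C, H, W, D) heatmap against a 3D reference keeps all three spatial dims...
assert spatial_shape_from_heatmap(torch.zeros(1, 16, 16, 16), ref_spatial_ndim=3) == (16, 16, 16)
# ...while the same rank against a 2D reference is read as (B, C, H, W).
assert spatial_shape_from_heatmap(torch.zeros(2, 1, 16, 16), ref_spatial_ndim=2) == (16, 16)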

tests/transforms/test_generate_heatmap.py

Lines changed: 19 additions & 0 deletions
@@ -223,6 +223,25 @@ def test_dict_batched_with_ref(self):
         max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0]
         np.testing.assert_allclose(max_vals.cpu().numpy(), np.ones((2, 1)), rtol=1e-5, atol=1e-5)

+    def test_truncated_parameter(self):
+        # Test that truncated parameter correctly controls window size
+        pt = np.array([[8.0, 8.0]], dtype=np.float32)
+        sigma = 2.0
+
+        # Test with different truncated values
+        small_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=2.0)(pt)[0]
+        default_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=4.0)(pt)[0]  # default
+        large_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=6.0)(pt)[0]
+
+        # Larger truncated should capture more of the gaussian, resulting in slightly higher total sum
+        self.assertLess(small_truncated.sum(), default_truncated.sum())
+        self.assertLess(default_truncated.sum(), large_truncated.sum())
+
+        # All should have same peak value (normalized to 1.0)
+        np.testing.assert_allclose(small_truncated.max(), 1.0, rtol=1e-5)
+        np.testing.assert_allclose(default_truncated.max(), 1.0, rtol=1e-5)
+        np.testing.assert_allclose(large_truncated.max(), 1.0, rtol=1e-5)
+

 if __name__ == "__main__":
     unittest.main()
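
To exercise only the new test locally, a keyword filter keeps the run short (assuming pytest is available in the development environment; the module also runs under plain unittest as shown above):

python -m pytest tests/transforms/test_generate_heatmap.py -k test_truncated_parameter -v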
