Skip to content

Commit 15ec97a

Browse files
committed
fix formatting
Signed-off-by: sewon.jeon <[email protected]>
1 parent 62831e6 commit 15ec97a

File tree

3 files changed

+10
-59
lines changed

3 files changed

+10
-59
lines changed

monai/transforms/post/array.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -792,11 +792,7 @@ def __init__(
792792
self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
793793
self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)
794794

795-
def __call__(
796-
self,
797-
points: NdarrayOrTensor,
798-
spatial_shape: Sequence[int] | None = None,
799-
) -> NdarrayOrTensor:
795+
def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
800796
original_points = points
801797
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
802798

@@ -871,11 +867,7 @@ def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
871867
return all(0 <= c < size for c, size in zip(center, bounds))
872868

873869
def _make_window(
874-
self,
875-
center: Sequence[float],
876-
radius: tuple[int, ...],
877-
bounds: tuple[int, ...],
878-
device: torch.device,
870+
self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device
879871
) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]:
880872
slices: list[slice] = []
881873
coord_shifts: list[torch.Tensor] = []

monai/transforms/post/dictionary.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -538,11 +538,7 @@ def __init__(
538538
self.ref_image_keys = self._prepare_optional_keys(ref_image_keys)
539539
self.static_shapes = self._prepare_shapes(spatial_shape)
540540
self.generator = GenerateHeatmap(
541-
sigma=sigma,
542-
spatial_shape=None,
543-
truncated=truncated,
544-
normalize=normalize,
545-
dtype=dtype,
541+
sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype
546542
)
547543

548544
def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
@@ -600,11 +596,7 @@ def _prepare_shapes(
600596
return tuple(prepared)
601597

602598
def _determine_shape(
603-
self,
604-
points: Any,
605-
static_shape: tuple[int, ...] | None,
606-
data: Mapping[Hashable, Any],
607-
ref_key: Hashable | None,
599+
self, points: Any, static_shape: tuple[int, ...] | None, data: Mapping[Hashable, Any], ref_key: Hashable | None
608600
) -> tuple[int, ...]:
609601
if static_shape is not None:
610602
return static_shape

tests/transforms/test_generate_heatmap.py

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,7 @@ def test_array_torch_device_and_dtype_propagation(self):
7272

7373
def test_array_channel_order_identity(self):
7474
# ensure the order of channels follows the order of input points
75-
pts = np.array(
76-
[
77-
[2.0, 2.0], # point A
78-
[12.0, 2.0], # point B
79-
[2.0, 12.0], # point C
80-
],
81-
dtype=np.float32,
82-
)
75+
pts = np.array([[2.0, 2.0], [12.0, 2.0], [2.0, 12.0]], dtype=np.float32) # point A # point B # point C
8376
hm = GenerateHeatmap(sigma=1.2, spatial_shape=(16, 16))(pts)
8477
self.assertEqual(hm.shape, (3, 16, 16))
8578

@@ -90,11 +83,7 @@ def test_array_channel_order_identity(self):
9083
def test_array_points_out_of_bounds(self):
9184
# points outside spatial domain: heatmap should still be valid (no NaN/Inf) and not all-zeros
9285
pts = np.array(
93-
[
94-
[-5.0, -5.0], # outside top-left
95-
[100.0, 100.0], # outside bottom-right
96-
[8.0, 8.0], # inside
97-
],
86+
[[-5.0, -5.0], [100.0, 100.0], [8.0, 8.0]], # outside top-left # outside bottom-right # inside
9887
dtype=np.float32,
9988
)
10089
hm = GenerateHeatmap(sigma=2.0, spatial_shape=(16, 16))(pts)
@@ -118,12 +107,7 @@ def test_dict_with_reference_meta(self):
118107
image.meta["spatial_shape"] = (8, 8, 8)
119108
data = {"points": points, "image": image}
120109

121-
transform = GenerateHeatmapd(
122-
keys="points",
123-
heatmap_keys="heatmap",
124-
ref_image_keys="image",
125-
sigma=2.0,
126-
)
110+
transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=2.0)
127111

128112
result = transform(data)
129113
heatmap = result["heatmap"]
@@ -172,13 +156,7 @@ def test_dict_dtype_control(self):
172156
self.assertEqual(hm.dtype, torch.float16)
173157

174158
def test_array_batched_3d(self):
175-
points = np.array(
176-
[
177-
[[4.2, 7.8, 1.0]], # Batch 1
178-
[[12.3, 3.6, 2.0]], # Batch 2
179-
],
180-
dtype=np.float32,
181-
)
159+
points = np.array([[[4.2, 7.8, 1.0]], [[12.3, 3.6, 2.0]]], dtype=np.float32) # Batch 1 # Batch 2
182160
transform = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16, 16))
183161

184162
heatmap = transform(points)
@@ -193,25 +171,14 @@ def test_array_batched_3d(self):
193171
self.assertTrue(np.all(np.abs(peak - points[i, 0]) <= 1.0), msg=f"peak={peak}, point={points[i, 0]}")
194172

195173
def test_dict_batched_with_ref(self):
196-
points = torch.tensor(
197-
[
198-
[[1.5, 2.5, 3.5]], # Batch 1
199-
[[4.5, 5.5, 6.5]], # Batch 2
200-
],
201-
dtype=torch.float32,
202-
)
174+
points = torch.tensor([[[1.5, 2.5, 3.5]], [[4.5, 5.5, 6.5]]], dtype=torch.float32) # Batch 1 # Batch 2
203175
affine = torch.eye(4)
204176
# A single reference image is used for the whole batch
205177
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine)
206178
image.meta["spatial_shape"] = (8, 8, 8)
207179
data = {"points": points, "image": image}
208180

209-
transform = GenerateHeatmapd(
210-
keys="points",
211-
heatmap_keys="heatmap",
212-
ref_image_keys="image",
213-
sigma=1.0,
214-
)
181+
transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=1.0)
215182

216183
result = transform(data)
217184
heatmap = result["heatmap"]

0 commit comments

Comments
 (0)