@@ -72,14 +72,7 @@ def test_array_torch_device_and_dtype_propagation(self):
72
72
73
73
def test_array_channel_order_identity (self ):
74
74
# ensure the order of channels follows the order of input points
75
- pts = np .array (
76
- [
77
- [2.0 , 2.0 ], # point A
78
- [12.0 , 2.0 ], # point B
79
- [2.0 , 12.0 ], # point C
80
- ],
81
- dtype = np .float32 ,
82
- )
75
+ pts = np .array ([[2.0 , 2.0 ], [12.0 , 2.0 ], [2.0 , 12.0 ]], dtype = np .float32 ) # point A # point B # point C
83
76
hm = GenerateHeatmap (sigma = 1.2 , spatial_shape = (16 , 16 ))(pts )
84
77
self .assertEqual (hm .shape , (3 , 16 , 16 ))
85
78
@@ -90,11 +83,7 @@ def test_array_channel_order_identity(self):
90
83
def test_array_points_out_of_bounds (self ):
91
84
# points outside spatial domain: heatmap should still be valid (no NaN/Inf) and not all-zeros
92
85
pts = np .array (
93
- [
94
- [- 5.0 , - 5.0 ], # outside top-left
95
- [100.0 , 100.0 ], # outside bottom-right
96
- [8.0 , 8.0 ], # inside
97
- ],
86
+ [[- 5.0 , - 5.0 ], [100.0 , 100.0 ], [8.0 , 8.0 ]], # outside top-left # outside bottom-right # inside
98
87
dtype = np .float32 ,
99
88
)
100
89
hm = GenerateHeatmap (sigma = 2.0 , spatial_shape = (16 , 16 ))(pts )
@@ -118,12 +107,7 @@ def test_dict_with_reference_meta(self):
118
107
image .meta ["spatial_shape" ] = (8 , 8 , 8 )
119
108
data = {"points" : points , "image" : image }
120
109
121
- transform = GenerateHeatmapd (
122
- keys = "points" ,
123
- heatmap_keys = "heatmap" ,
124
- ref_image_keys = "image" ,
125
- sigma = 2.0 ,
126
- )
110
+ transform = GenerateHeatmapd (keys = "points" , heatmap_keys = "heatmap" , ref_image_keys = "image" , sigma = 2.0 )
127
111
128
112
result = transform (data )
129
113
heatmap = result ["heatmap" ]
@@ -172,13 +156,7 @@ def test_dict_dtype_control(self):
172
156
self .assertEqual (hm .dtype , torch .float16 )
173
157
174
158
def test_array_batched_3d (self ):
175
- points = np .array (
176
- [
177
- [[4.2 , 7.8 , 1.0 ]], # Batch 1
178
- [[12.3 , 3.6 , 2.0 ]], # Batch 2
179
- ],
180
- dtype = np .float32 ,
181
- )
159
+ points = np .array ([[[4.2 , 7.8 , 1.0 ]], [[12.3 , 3.6 , 2.0 ]]], dtype = np .float32 ) # Batch 1 # Batch 2
182
160
transform = GenerateHeatmap (sigma = 1.5 , spatial_shape = (16 , 16 , 16 ))
183
161
184
162
heatmap = transform (points )
@@ -193,25 +171,14 @@ def test_array_batched_3d(self):
193
171
self .assertTrue (np .all (np .abs (peak - points [i , 0 ]) <= 1.0 ), msg = f"peak={ peak } , point={ points [i , 0 ]} " )
194
172
195
173
def test_dict_batched_with_ref (self ):
196
- points = torch .tensor (
197
- [
198
- [[1.5 , 2.5 , 3.5 ]], # Batch 1
199
- [[4.5 , 5.5 , 6.5 ]], # Batch 2
200
- ],
201
- dtype = torch .float32 ,
202
- )
174
+ points = torch .tensor ([[[1.5 , 2.5 , 3.5 ]], [[4.5 , 5.5 , 6.5 ]]], dtype = torch .float32 ) # Batch 1 # Batch 2
203
175
affine = torch .eye (4 )
204
176
# A single reference image is used for the whole batch
205
177
image = MetaTensor (torch .zeros ((1 , 8 , 8 , 8 ), dtype = torch .float32 ), affine = affine )
206
178
image .meta ["spatial_shape" ] = (8 , 8 , 8 )
207
179
data = {"points" : points , "image" : image }
208
180
209
- transform = GenerateHeatmapd (
210
- keys = "points" ,
211
- heatmap_keys = "heatmap" ,
212
- ref_image_keys = "image" ,
213
- sigma = 1.0 ,
214
- )
181
+ transform = GenerateHeatmapd (keys = "points" , heatmap_keys = "heatmap" , ref_image_keys = "image" , sigma = 1.0 )
215
182
216
183
result = transform (data )
217
184
heatmap = result ["heatmap" ]
0 commit comments