Skip to content

Commit 5de03b9

Browse files
committed
update tests
1 parent d6d4cfc commit 5de03b9

File tree

2 files changed

+64
-16
lines changed

2 files changed

+64
-16
lines changed

pytorch_sparse_utils/conversion.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,9 @@ def torch_sparse_to_minkowski(tensor: Tensor):
6363
assert isinstance(tensor, Tensor)
6464
assert tensor.is_sparse
6565
features = tensor.values()
66-
coordinates = tensor.indices()
6766
if features.ndim == 1:
6867
features = features.unsqueeze(-1)
69-
coordinates = coordinates[:-1]
70-
coordinates = coordinates.transpose(0, 1).contiguous().int()
68+
coordinates = tensor.indices().T.int().contiguous()
7169
return ME.SparseTensor(
7270
features, coordinates, requires_grad=tensor.requires_grad, device=tensor.device
7371
)
@@ -77,11 +75,21 @@ def torch_sparse_to_minkowski(tensor: Tensor):
7775
def minkowski_to_torch_sparse(
7876
tensor: Union[Tensor, ME.SparseTensor],
7977
full_scale_spatial_shape: Optional[Union[Tensor, list[int]]] = None,
78+
squeeze: bool = False
8079
) -> Tensor:
8180
"""Converts a MinkowskiEngine SparseTensor to an equivalent sparse torch.Tensor
8281
8382
Args:
8483
tensor (MinkowskiEngine.SparseTensor): Sparse tensor to be converted
84+
full_scale_spatial_shape (Optional[Union[list[int], [Tensor]]]): The full
85+
extent of the spatial domain on which the sparse data reside.
86+
If given, will be used to define the size of the sparse tensor. If not
87+
given, the size will be inferred from the indices in the tensor.
88+
Default: None
89+
squeeze (bool): If True and the feature dimension of the MinkowskiEngine
90+
SparseTensor is 1, the returned sparse torch.Tensor will have its values
91+
squeezed to 1D shape of [nnz] rather than [nnz, 1]. Raises an error if
92+
True and the feature dimension is not 1.
8593
8694
Returns:
8795
tensor (torch.Tensor): Converted sparse tensor
@@ -102,6 +110,18 @@ def minkowski_to_torch_sparse(
102110
else:
103111
max_coords = None
104112
out = __me_sparse(tensor, min_coords, max_coords)[0].coalesce()
113+
if squeeze:
114+
if out.values().shape[1] != 1:
115+
raise ValueError(
116+
"Got `squeeze`=True, but the MinkowskiEngine tensor has a feature "
117+
f"dim of {out.values().shape[1]}, not 1."
118+
)
119+
out = torch.sparse_coo_tensor(
120+
out.indices(),
121+
out.values().squeeze(-1),
122+
out.shape[:-1],
123+
is_coalesced=out.is_coalesced()
124+
)
105125
return out
106126

107127

tests/test_conversion.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def test_large_sparse_array(self):
133133

134134

135135
@pytest.mark.skipif(not has_minkowskiengine, reason="MinkowskiEngine not installed")
136-
class TestMinkowskiEngineReal:
137-
def test_torch_sparse_to_minkowski_real(self):
136+
class TestMinkowskiEngine:
137+
def test_torch_sparse_to_minkowski(self):
138138
# Create sparse tensor with 3D coordinates (batch, x, y) and features
139139
indices = torch.tensor(
140140
[[0, 0, 1, 1], [0, 1, 0, 1], [0, 0, 1, 1], [3, 4, 5, 6]], dtype=torch.int
@@ -145,12 +145,11 @@ def test_torch_sparse_to_minkowski_real(self):
145145

146146
result = torch_sparse_to_minkowski(tensor)
147147

148-
assert hasattr(result, "F") # Features
149-
assert hasattr(result, "C") # Coordinates
148+
assert isinstance(result, ME.SparseTensor)
150149
assert result.F.shape[0] == 4 # Number of points
151150
assert result.F.shape[1] == 1 # Feature dimension
152151

153-
def test_minkowski_to_torch_sparse_real(self):
152+
def test_minkowski_to_torch_sparse(self):
154153
# Create a MinkowskiEngine SparseTensor
155154
coordinates = torch.tensor(
156155
[[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]], dtype=torch.int
@@ -164,23 +163,53 @@ def test_minkowski_to_torch_sparse_real(self):
164163
assert result.is_sparse
165164
assert result.values().numel() == 4
166165

167-
def test_roundtrip_minkowski(self):
168-
# Test roundtrip conversion
166+
def test_already_torch_sparse(self):
169167
indices = torch.tensor(
170-
[[0, 0, 1], [0, 1, 0], [0, 1, 1], [2, 3, 4]], dtype=torch.int
168+
[[0, 0, 1, 1], [0, 1, 0, 1], [0, 0, 1, 1], [3, 4, 5, 6]], dtype=torch.int
171169
)
170+
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
171+
shape = (2, 2, 2, 10) # batch_size=2, spatial=(2,2), features=10
172+
tensor = torch.sparse_coo_tensor(indices, values, shape).coalesce()
173+
174+
result = minkowski_to_torch_sparse(tensor)
175+
176+
assert result is tensor
177+
178+
def test_roundtrip_minkowski(self):
179+
# Test roundtrip conversion
180+
indices = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 1, 1]], dtype=torch.int).T
172181
values = torch.tensor([1.0, 2.0, 3.0])
173-
shape = (2, 2, 2, 5)
182+
shape = (2, 2, 2)
174183
tensor = torch.sparse_coo_tensor(indices, values, shape).coalesce()
175184

176185
me_tensor = torch_sparse_to_minkowski(tensor)
177186
back_to_torch = minkowski_to_torch_sparse(
178-
me_tensor, full_scale_spatial_shape=[2, 2]
187+
me_tensor, full_scale_spatial_shape=[2, 2], squeeze=True
179188
)
180189

181-
# Check that we get back similar structure (may not be identical due to coordinate handling)
182190
assert back_to_torch.is_sparse
183-
assert back_to_torch.shape[0] == 2 # batch size preserved
191+
assert back_to_torch.shape[0] == 2
192+
assert torch.equal(back_to_torch.indices(), tensor.indices())
193+
assert torch.equal(back_to_torch.values(), tensor.values())
194+
195+
# Test with tensor spatial shape
196+
back_to_torch_2 = minkowski_to_torch_sparse(
197+
me_tensor, full_scale_spatial_shape=torch.tensor([2, 2]), squeeze=True
198+
)
199+
200+
assert torch.equal(back_to_torch.indices(), back_to_torch_2.indices())
201+
assert torch.equal(back_to_torch.values(), back_to_torch_2.values())
202+
203+
def test_squeeze_error(self):
204+
# Test trying to squeeze without scalar features
205+
coordinates = torch.tensor(
206+
[[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]], dtype=torch.int
207+
)
208+
features = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
209+
sparse_tensor = ME.SparseTensor(features, coordinates)
210+
211+
with pytest.raises(ValueError, match="Got `squeeze`=True"):
212+
_ = minkowski_to_torch_sparse(sparse_tensor, squeeze=True)
184213

185214

186215
@pytest.mark.skipif(not has_spconv, reason="spconv not installed")
@@ -244,7 +273,6 @@ def test_spconv_squeeze_error(self):
244273
with pytest.raises(ValueError, match="Got `squeeze`=True, but"):
245274
_ = spconv_to_torch_sparse(sparse_conv_tensor, squeeze=True)
246275

247-
248276
def test_roundtrip_spconv(self):
249277
# Test roundtrip conversion
250278
indices = torch.tensor([[0, 0, 1], [0, 1, 0], [0, 1, 1], [0, 0, 1]]).T

0 commit comments

Comments
 (0)