Skip to content

Commit ea602f4

Browse files
committed
add tests
1 parent f51766e commit ea602f4

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed

tests/test_minkowski_spconv.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import torch
2+
import pytest
3+
4+
from pytorch_sparse_utils.conversion import (
5+
torch_sparse_to_minkowski,
6+
torch_sparse_to_spconv,
7+
)
8+
from pytorch_sparse_utils.imports import has_minkowskiengine, ME, has_spconv, spconv
9+
10+
from pytorch_sparse_utils.minkowskiengine import (
11+
MinkowskiGELU,
12+
MinkowskiLayerNorm,
13+
get_me_layer,
14+
MinkowskiNonlinearityBase,
15+
)
16+
from pytorch_sparse_utils.spconv import spconv_sparse_mult
17+
18+
19+
@pytest.mark.skipif(not has_minkowskiengine, reason="MinkowskiEngine not installed")
20+
@pytest.mark.cpu_and_cuda
21+
class TestMinkowskiEngineUtils:
22+
def test_minkowski_layer_norm(self, device):
23+
indices = torch.tensor([[0, 0], [0, 1]], device=device).T
24+
values = torch.randn(2, 8, device=device)
25+
tensor = torch.sparse_coo_tensor(indices, values).coalesce()
26+
27+
me_tensor = torch_sparse_to_minkowski(tensor)
28+
29+
norm = MinkowskiLayerNorm(8).to(device)
30+
31+
out = norm(me_tensor)
32+
assert isinstance(out, ME.SparseTensor)
33+
assert not torch.equal(me_tensor.F, out.F)
34+
assert torch.equal(me_tensor.C, out.C)
35+
36+
me_tensor_field = ME.TensorField(me_tensor.F, me_tensor.C)
37+
38+
out_2 = norm(me_tensor_field)
39+
assert isinstance(out_2, ME.TensorField)
40+
assert not torch.equal(me_tensor_field.F, out_2.F)
41+
assert torch.equal(me_tensor_field.C, out_2.C)
42+
43+
def test_minkowski_gelu(self, device):
44+
indices = torch.tensor([[0, 0], [0, 1]], device=device).T
45+
values = torch.randn(2, 8, device=device)
46+
tensor = torch.sparse_coo_tensor(indices, values).coalesce()
47+
48+
me_tensor = torch_sparse_to_minkowski(tensor)
49+
50+
gelu = MinkowskiGELU()
51+
assert isinstance(gelu, MinkowskiNonlinearityBase)
52+
53+
out = gelu(me_tensor)
54+
55+
assert isinstance(out, ME.SparseTensor)
56+
assert not torch.equal(me_tensor.F, out.F)
57+
assert torch.equal(me_tensor.C, out.C)
58+
59+
def test_get_me_layer(self):
60+
module = get_me_layer(MinkowskiGELU()) # pyright: ignore[reportArgumentType]
61+
assert isinstance(module, MinkowskiGELU)
62+
63+
relu = get_me_layer("relu")
64+
assert isinstance(relu(), ME.MinkowskiReLU)
65+
66+
gelu = get_me_layer("gelu")
67+
assert isinstance(gelu(), MinkowskiGELU)
68+
69+
bn = get_me_layer("batchnorm1d")
70+
assert isinstance(
71+
bn(8), ME.MinkowskiBatchNorm # pyright: ignore[reportCallIssue]
72+
)
73+
74+
with pytest.raises(ValueError, match="Unexpected layer"):
75+
get_me_layer("fdsfdsf")
76+
77+
78+
@pytest.mark.skipif(not has_spconv, reason="spconv not installed")
79+
@pytest.mark.cpu_and_cuda
80+
class TestSpConvUtils:
81+
def test_spconv_sparse_mult(self, device):
82+
indices = torch.tensor([[0, 0], [0, 1]], device=device).T
83+
values = torch.randn(2, 8, device=device)
84+
tensor = torch.sparse_coo_tensor(indices, values).coalesce()
85+
86+
spconv_tensor = torch_sparse_to_spconv(tensor)
87+
assert isinstance(spconv_tensor, spconv.SparseConvTensor)
88+
89+
out = spconv_sparse_mult(spconv_tensor, spconv_tensor)
90+
91+
assert torch.equal(
92+
out.indices, spconv_tensor.indices # pyright: ignore[reportArgumentType]
93+
)
94+
assert not torch.equal(out.features, spconv_tensor.features)
95+
96+
def test_spconv_sparse_mult_different_indices(self, device):
97+
indices = torch.tensor([[0, 0], [1, 1], [1, 2]], device=device).T
98+
values = torch.randn(3, 8, device=device)
99+
tensor = torch.sparse_coo_tensor(indices, values, (2, 5, 8)).coalesce()
100+
101+
spconv_tensor_1 = torch_sparse_to_spconv(tensor)
102+
103+
indices = torch.tensor([[1, 0], [0, 1], [0, 0]], device=device).T
104+
values = torch.randn(3, 8, device=device)
105+
tensor = torch.sparse_coo_tensor(indices, values, (2, 5, 8)).coalesce()
106+
107+
spconv_tensor_2 = torch_sparse_to_spconv(tensor)
108+
109+
assert not torch.equal(
110+
spconv_tensor_1.indices, # pyright: ignore[reportArgumentType]
111+
spconv_tensor_2.indices, # pyright: ignore[reportArgumentType]
112+
)
113+
114+
out = spconv_sparse_mult(spconv_tensor_1, spconv_tensor_2)
115+
116+
assert not torch.equal(
117+
out.features,
118+
spconv_tensor_1.features, # pyright: ignore[reportArgumentType]
119+
)
120+
assert not torch.equal(
121+
out.features,
122+
spconv_tensor_2.features, # pyright: ignore[reportArgumentType]
123+
)

0 commit comments

Comments
 (0)