|
| 1 | +import torch |
| 2 | +import pytest |
| 3 | + |
| 4 | +from pytorch_sparse_utils.indexing.unique import unique_rows |
| 5 | +from pytorch_sparse_utils.indexing.scatter import scatter_to_sparse_tensor |
| 6 | + |
| 7 | +@pytest.mark.cpu_and_cuda |
| 8 | +class TestUniqueRows: |
| 9 | + def test_basic_functionality(self, device): |
| 10 | + tensor = torch.tensor([ |
| 11 | + [1, 2, 3], |
| 12 | + [1, 2, 3], |
| 13 | + [7, 8, 9], |
| 14 | + [4, 5, 6], |
| 15 | + [4, 5, 6], |
| 16 | + ], device=device |
| 17 | + ) |
| 18 | + unique_inds = unique_rows(tensor) |
| 19 | + |
| 20 | + assert torch.equal(unique_inds, torch.tensor([0, 2, 3], device=device)) |
| 21 | + |
| 22 | + def test_negative_values(self, device): |
| 23 | + tensor = torch.tensor( |
| 24 | + [ |
| 25 | + [1, 2, 3], |
| 26 | + [-1, -2, -3], |
| 27 | + [-1, -2, -3], |
| 28 | + [4, -10, 3], |
| 29 | + ], device=device |
| 30 | + ) |
| 31 | + unique_inds = unique_rows(tensor) |
| 32 | + |
| 33 | + assert torch.equal(unique_inds, torch.tensor([0, 1, 3], device=device)) |
| 34 | + |
| 35 | + def test_sorted(self, device): |
| 36 | + tensor = torch.tensor([ |
| 37 | + [-1, -3, 5], |
| 38 | + [3, 20, 44], |
| 39 | + [1, 2, 3], |
| 40 | + [-1, -3, 5], |
| 41 | + [3, 20, 44] |
| 42 | + ], device=device) |
| 43 | + |
| 44 | + unique_unsorted = unique_rows(tensor, sorted=False) |
| 45 | + unique_sorted = unique_rows(tensor, sorted=True) |
| 46 | + |
| 47 | + assert not torch.equal(unique_sorted, unique_unsorted) |
| 48 | + assert torch.equal(unique_sorted, torch.tensor([0, 1, 2], device=device)) |
| 49 | + assert torch.equal(unique_unsorted, torch.tensor([0, 2, 1], device=device)) |
| 50 | + |
| 51 | + def test_error_wrong_dim(self, device): |
| 52 | + tensor = torch.randint(0, 100, size=(10,), device=device) |
| 53 | + with pytest.raises( |
| 54 | + (ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType] |
| 55 | + match="Expected a 2D tensor" |
| 56 | + ): |
| 57 | + unique_rows(tensor) |
| 58 | + |
| 59 | + def test_error_not_int(self, device): |
| 60 | + tensor_float = torch.randn(10, 10, device=device) |
| 61 | + with pytest.raises( |
| 62 | + (ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType] |
| 63 | + match="Expected integer tensor" |
| 64 | + ): |
| 65 | + unique_rows(tensor_float) |
| 66 | + |
| 67 | + tensor_complex = torch.randn(10, 10, device=device, dtype=torch.complex64) |
| 68 | + with pytest.raises( |
| 69 | + (ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType] |
| 70 | + match="Expected integer tensor" |
| 71 | + ): |
| 72 | + unique_rows(tensor_complex) |
| 73 | + |
| 74 | + def test_error_overflow(self, device): |
| 75 | + tensor = torch.randint(-100, 100, size=(10, 4), device=device, dtype=torch.long) |
| 76 | + tensor[0, :] = torch.iinfo(torch.long).max |
| 77 | + with pytest.raises( |
| 78 | + (OverflowError, torch.jit.Error), # pyright: ignore[reportArgumentType] |
| 79 | + match="Tensor contains values at or near" |
| 80 | + ): |
| 81 | + unique_rows(tensor) |
| 82 | + |
| 83 | + tensor[0, :] = torch.iinfo(torch.long).max - 100 |
| 84 | + with pytest.raises( |
| 85 | + (OverflowError, torch.jit.Error), # pyright: ignore[reportArgumentType] |
| 86 | + match="would cause integer overflow" |
| 87 | + ): |
| 88 | + unique_rows(tensor) |
0 commit comments