Skip to content

Commit 610c582

Browse files
committed
add tests for unique_rows
1 parent 0579a8c commit 610c582

File tree

2 files changed

+104
-10
lines changed

2 files changed

+104
-10
lines changed

pytorch_sparse_utils/indexing/unique.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def unique_rows(tensor: Tensor, sorted: bool = True) -> Tensor:
1010
Args:
1111
tensor (Tensor): A 2D tensor of integer type.
1212
sorted (bool): Whether to sort the indices of unique rows before returning.
13-
If False, returned indices will be in lexicographic order.
13+
If False, returned indices will be in lexicographic order of the rows.
1414
1515
Returns:
1616
Tensor: A 1D tensor whose elements are the indices of the unique rows of
@@ -43,17 +43,15 @@ def unique_rows(tensor: Tensor, sorted: bool = True) -> Tensor:
4343
max_vals = tensor.max(0).values
4444
min_vals = tensor.min(0).values
4545

46-
# Handle negative values by shifting to nonnegative
47-
has_negs = min_vals < 0
48-
if has_negs.any():
49-
# Shift each column to be nonnegative
50-
neg_shift = torch.where(has_negs, min_vals, min_vals.new_zeros([]))
51-
tensor = tensor - neg_shift
52-
max_vals = max_vals - neg_shift
53-
5446
# Check for overflow problems
55-
log_sum = (max_vals + 1).log().sum()
5647
INT64_MAX = 9223372036854775807
48+
if (max_vals >= INT64_MAX).any():
49+
raise OverflowError(
50+
f"Tensor contains values at or near maximum int64 value ({INT64_MAX}), "
51+
"which would lead to overflow errors when computing unique rows."
52+
)
53+
54+
log_sum = (max_vals + 1).log().sum()
5755
log_max = torch.tensor(INT64_MAX, device=max_vals.device).log()
5856

5957
if log_sum > log_max:
@@ -62,6 +60,14 @@ def unique_rows(tensor: Tensor, sorted: bool = True) -> Tensor:
6260
f"approx {log_sum.exp()} compared to max int64 value of {INT64_MAX}."
6361
)
6462

63+
# Handle negative values by shifting to nonnegative
64+
has_negs = min_vals < 0
65+
if has_negs.any():
66+
# Shift each column to be nonnegative
67+
neg_shift = torch.where(has_negs, min_vals, min_vals.new_zeros([]))
68+
tensor = tensor - neg_shift
69+
max_vals = max_vals - neg_shift
70+
6571
tensor_flat, _ = flatten_nd_indices(tensor.T.long(), max_vals)
6672
tensor_flat: Tensor = tensor_flat.squeeze(0)
6773

tests/indexing/test_misc.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
import pytest
3+
4+
from pytorch_sparse_utils.indexing.unique import unique_rows
5+
from pytorch_sparse_utils.indexing.scatter import scatter_to_sparse_tensor
6+
7+
@pytest.mark.cpu_and_cuda
8+
class TestUniqueRows:
9+
def test_basic_functionality(self, device):
10+
tensor = torch.tensor([
11+
[1, 2, 3],
12+
[1, 2, 3],
13+
[7, 8, 9],
14+
[4, 5, 6],
15+
[4, 5, 6],
16+
], device=device
17+
)
18+
unique_inds = unique_rows(tensor)
19+
20+
assert torch.equal(unique_inds, torch.tensor([0, 2, 3], device=device))
21+
22+
def test_negative_values(self, device):
23+
tensor = torch.tensor(
24+
[
25+
[1, 2, 3],
26+
[-1, -2, -3],
27+
[-1, -2, -3],
28+
[4, -10, 3],
29+
], device=device
30+
)
31+
unique_inds = unique_rows(tensor)
32+
33+
assert torch.equal(unique_inds, torch.tensor([0, 1, 3], device=device))
34+
35+
def test_sorted(self, device):
36+
tensor = torch.tensor([
37+
[-1, -3, 5],
38+
[3, 20, 44],
39+
[1, 2, 3],
40+
[-1, -3, 5],
41+
[3, 20, 44]
42+
], device=device)
43+
44+
unique_unsorted = unique_rows(tensor, sorted=False)
45+
unique_sorted = unique_rows(tensor, sorted=True)
46+
47+
assert not torch.equal(unique_sorted, unique_unsorted)
48+
assert torch.equal(unique_sorted, torch.tensor([0, 1, 2], device=device))
49+
assert torch.equal(unique_unsorted, torch.tensor([0, 2, 1], device=device))
50+
51+
def test_error_wrong_dim(self, device):
52+
tensor = torch.randint(0, 100, size=(10,), device=device)
53+
with pytest.raises(
54+
(ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType]
55+
match="Expected a 2D tensor"
56+
):
57+
unique_rows(tensor)
58+
59+
def test_error_not_int(self, device):
60+
tensor_float = torch.randn(10, 10, device=device)
61+
with pytest.raises(
62+
(ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType]
63+
match="Expected integer tensor"
64+
):
65+
unique_rows(tensor_float)
66+
67+
tensor_complex = torch.randn(10, 10, device=device, dtype=torch.complex64)
68+
with pytest.raises(
69+
(ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType]
70+
match="Expected integer tensor"
71+
):
72+
unique_rows(tensor_complex)
73+
74+
def test_error_overflow(self, device):
75+
tensor = torch.randint(-100, 100, size=(10, 4), device=device, dtype=torch.long)
76+
tensor[0, :] = torch.iinfo(torch.long).max
77+
with pytest.raises(
78+
(OverflowError, torch.jit.Error), # pyright: ignore[reportArgumentType]
79+
match="Tensor contains values at or near"
80+
):
81+
unique_rows(tensor)
82+
83+
tensor[0, :] = torch.iinfo(torch.long).max - 100
84+
with pytest.raises(
85+
(OverflowError, torch.jit.Error), # pyright: ignore[reportArgumentType]
86+
match="would cause integer overflow"
87+
):
88+
unique_rows(tensor)

0 commit comments

Comments
 (0)