Skip to content

Commit 8baa4eb

Browse files
committed
bugfix
1 parent 41b0d65 commit 8baa4eb

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ tests = [
2222
"pytest-env",
2323
"hypothesis",
2424
"pytest-cov",
25+
"pytest-xdist",
2526
]
2627
minkowskiengine = ["MinkowskiEngine"]
2728
spconv = ["spconv"]

pytorch_sparse_utils/utils/batch_topk.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ def batch_topk(
8181
each concatenated subsequence. Default: 0 (sequence dimension).
8282
largest (bool, optional): If True, returns the indices of the largest elements.
8383
If False, returns those of the smallest elements. Default: True.
84-
sorted (bool, optional): If True, returns the elements in sorted order.
85-
Default: True.
84+
sorted (bool, optional): If True, always returns the elements in sorted order.
85+
For technical reasons, the returned elements may be sorted in some cases
86+
even when False. Default: True.
8687
return_values (bool, optional): If True, the output namedtuple will include the
8788
topk values in addition to the indices and offsets. Default: False.
8889
@@ -213,8 +214,8 @@ def batch_topk(
213214
topk_dim = dim + 1 # account for new leading batch dim
214215

215216
values_all, indices_all = tensor.reshape(batch_shape).topk(
216-
k_max_int, topk_dim, largest=largest, sorted=sorted
217-
)
217+
k_max_int, topk_dim, largest=largest, sorted=True
218+
) # Need to be sorted to be able to select first k for each subseq
218219

219220
# If topk is along sequence length, need to add offsets to indices
220221
# to globalize them

tests/utils/test_batch_topk.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22

33
import pytest
44
import torch
5-
from hypothesis import HealthCheck, given, settings
5+
from hypothesis import HealthCheck, example, given, settings
66
from hypothesis import strategies as st
77
from torch import Tensor
88

9+
from pytorch_sparse_utils.batching import (
10+
batch_offsets_to_seq_lengths,
11+
seq_lengths_to_batch_offsets,
12+
)
913
from pytorch_sparse_utils.utils import (
10-
batch_topk,
1114
BatchTopK,
15+
batch_topk,
1216
unpack_batch_topk,
1317
)
14-
from pytorch_sparse_utils.batching import (
15-
seq_lengths_to_batch_offsets,
16-
)
1718

1819

1920
# Helper utils
@@ -242,6 +243,17 @@ def test_negative_k_raises(self, device):
242243
batch_topk(t, off, k=-1)
243244

244245
# Property-based test
246+
@example(
247+
params={
248+
"seq_lens": [3, 3],
249+
"extra_dims": [],
250+
"dim": 0,
251+
"k": [1, 3],
252+
"largest": False,
253+
"sorted_": False,
254+
"seed": 0,
255+
},
256+
)
245257
@settings(deadline=None, suppress_health_check=[HealthCheck.differing_executors])
246258
@given(params=batch_topk_inputs())
247259
def test_property(self, params, device):
@@ -275,6 +287,14 @@ def test_property(self, params, device):
275287
else: # tensor
276288
k_per_batch = params["k"].tolist()
277289

290+
# Determine if batch_topk will need to actually sort indices even if
291+
# sorted = False
292+
if not params["sorted_"]:
293+
n_seq_lengths = batch_offsets_to_seq_lengths(offsets).unique()
294+
if n_seq_lengths.numel() == 1:
295+
params["sorted_"] = True
296+
297+
278298
ref_idx, ref_off, ref_vals = topk_reference(
279299
tensor_ref,
280300
offsets,

0 commit comments

Comments (0)