Skip to content

Commit 1609bf4

Browse files
authored
Merge pull request #361 from ev-br/torch_take_along_axis_neg_idx
ENH: torch: allow negative indices in `take` and `take_along_axis`
2 parents a6f5c3f + 4355ab8 commit 1609bf4

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,11 +815,24 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje
815815
if x.ndim != 1:
816816
raise ValueError("axis must be specified when ndim > 1")
817817
axis = 0
818-
return torch.index_select(x, axis, indices, **kwargs)
818+
# torch does not support negative indices,
819+
# see https://github.com/pytorch/pytorch/issues/146211
820+
return torch.index_select(
821+
x,
822+
axis,
823+
torch.where(indices < 0, indices + x.shape[axis], indices),
824+
**kwargs
825+
)
819826

820827

821828
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
822-
return torch.take_along_dim(x, indices, dim=axis)
829+
# torch does not support negative indices,
830+
# see https://github.com/pytorch/pytorch/issues/146211
831+
return torch.take_along_dim(
832+
x,
833+
torch.where(indices < 0, indices + x.shape[axis], indices),
834+
dim=axis
835+
)
823836

824837

825838
def sign(x: Array, /) -> Array:

0 commit comments

Comments
 (0)