diff --git a/namedtensor/test_core.py b/namedtensor/test_core.py index 1e31022..2fed7c3 100644 --- a/namedtensor/test_core.py +++ b/namedtensor/test_core.py @@ -424,6 +424,19 @@ def test_takes(): ) +def test_index_add(): + + base = ntorch.zeros(3, 5, names=("beta", "alpha")) + vals = ntorch.randn(3, 4, names=("beta", "time")) + index = ntorch.randint(0, 5, (4,), names=("time")) + + result = base.index_add_("alpha", index, vals) + expected_result = base.values.index_add_(1, index.values, vals.values) + + assert (result._tensor == expected_result).all() + assert result.shape == OrderedDict([("beta", 3), ("alpha", 5)]) + + def test_narrow(): base1 = torch.randn(10, 2, 50) diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py index 7601bd2..8f18e1b 100644 --- a/namedtensor/torch_base.py +++ b/namedtensor/torch_base.py @@ -228,6 +228,19 @@ def index_select(self, dim, index): new_names, ) + @staticmethod + def index_add_(self, dim, index, tensor): + """Accumulate the elements of 'tensor' into the self tensor + by adding to the indices in the order given in 'index'.""" + name = dim + dim = self._schema.get(name) + tensor_names = [i for i in self._schema._names] + tensor_names[dim] = index._schema._names[0] + self._tensor.index_add_( + dim, index._tensor, tensor._force_order(tensor_names)._tensor + ) + return self + @staticmethod def index_fill_(self, dim, index, val): "Index into dimension names with the `index` named tensors." diff --git a/namedtensor/torch_helpers.py b/namedtensor/torch_helpers.py index a868c30..8ffa472 100644 --- a/namedtensor/torch_helpers.py +++ b/namedtensor/torch_helpers.py @@ -457,6 +457,7 @@ def __dir__(self): "masked_scatter", "masked_fill_", "index_select", + "index_add_", "index_copy_", "index_fill_", "topk",