From 4cd95f3b31c1322fee415efd7c7488c89ac4c1c7 Mon Sep 17 00:00:00 2001 From: Asher Spector Date: Wed, 24 Apr 2019 16:34:39 -0400 Subject: [PATCH 1/3] Added index_add_ --- namedtensor/test_core.py | 14 ++++++++++++++ namedtensor/torch_base.py | 15 +++++++++++++++ namedtensor/torch_helpers.py | 1 + 3 files changed, 30 insertions(+) diff --git a/namedtensor/test_core.py b/namedtensor/test_core.py index 1e31022..e745309 100644 --- a/namedtensor/test_core.py +++ b/namedtensor/test_core.py @@ -424,6 +424,20 @@ def test_takes(): ) +def test_index_add(): + + base = ntorch.zeros(5, 3, names = ("alpha", "beta")) + vals = ntorch.randn(3, 3, names = ("time", "beta")) + index = ntorch.randint(0, 5, (3,), names = ("time")) + + result = base.index_add_('alpha', index, vals) + expected_result = base.values.index_add_(0, index.values, vals.values) + + assert (result._tensor == expected_result).all() + assert result.shape == OrderedDict( + [("alpha", 5), ("beta", 3)] + ) + def test_narrow(): base1 = torch.randn(10, 2, 50) diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py index 7601bd2..ca31a08 100644 --- a/namedtensor/torch_base.py +++ b/namedtensor/torch_base.py @@ -228,6 +228,21 @@ def index_select(self, dim, index): new_names, ) + + @staticmethod + def index_add_(self, dim, index, tensor): + """Accumulate the elements of 'tensor' into the self tensor + by adding to the indices in the order given in 'index'.""" + name = dim + dim = self._schema.get(name) + tensor_names = [n for n in tensor._schema._names if n in index._schema._names] + tensor_names += [n for n in tensor._schema._names if n not in index._schema._names] + self._tensor.index_add_( + dim, index._tensor, tensor._force_order(tensor_names)._tensor + ) + return self + + @staticmethod def index_fill_(self, dim, index, val): "Index into dimension names with the `index` named tensors." diff --git a/namedtensor/torch_helpers.py b/namedtensor/torch_helpers.py index a868c30..8ffa472 100644 --- a/namedtensor/torch_helpers.py +++ b/namedtensor/torch_helpers.py @@ -457,6 +457,7 @@ def __dir__(self): "masked_scatter", "masked_fill_", "index_select", + "index_add_", "index_copy_", "index_fill_", "topk", From 9d806848519ba9f01d7793804083123fd0663fbb Mon Sep 17 00:00:00 2001 From: Asher Spector Date: Wed, 24 Apr 2019 16:35:55 -0400 Subject: [PATCH 2/3] Code formatting for index_add_ --- namedtensor/test_core.py | 13 ++++++------- namedtensor/torch_base.py | 12 +++++++----- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/namedtensor/test_core.py b/namedtensor/test_core.py index e745309..6104ef5 100644 --- a/namedtensor/test_core.py +++ b/namedtensor/test_core.py @@ -426,17 +426,16 @@ def test_takes(): def test_index_add(): - base = ntorch.zeros(5, 3, names = ("alpha", "beta")) - vals = ntorch.randn(3, 3, names = ("time", "beta")) - index = ntorch.randint(0, 5, (3,), names = ("time")) + base = ntorch.zeros(5, 3, names=("alpha", "beta")) + vals = ntorch.randn(3, 3, names=("time", "beta")) + index = ntorch.randint(0, 5, (3,), names=("time")) - result = base.index_add_('alpha', index, vals) + result = base.index_add_("alpha", index, vals) expected_result = base.values.index_add_(0, index.values, vals.values) assert (result._tensor == expected_result).all() - assert result.shape == OrderedDict( - [("alpha", 5), ("beta", 3)] - ) + assert result.shape == OrderedDict([("alpha", 5), ("beta", 3)]) + def test_narrow(): base1 = torch.randn(10, 2, 50) diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py index ca31a08..8988021 100644 --- a/namedtensor/torch_base.py +++ b/namedtensor/torch_base.py @@ -228,21 +228,23 @@ def index_select(self, dim, index): new_names, ) - @staticmethod def index_add_(self, dim, index, tensor): """Accumulate the elements of 'tensor' into the self tensor by adding to the indices in the order given in 'index'.""" name = dim dim = self._schema.get(name) - tensor_names = [n for n in tensor._schema._names if n in index._schema._names] - tensor_names += [n for n in tensor._schema._names if n not in index._schema._names] + tensor_names = [ + n for n in tensor._schema._names if n in index._schema._names + ] + tensor_names += [ + n for n in tensor._schema._names if n not in index._schema._names + ] self._tensor.index_add_( - dim, index._tensor, tensor._force_order(tensor_names)._tensor + dim, index._tensor, tensor._force_order(tensor_names)._tensor ) return self - @staticmethod def index_fill_(self, dim, index, val): "Index into dimension names with the `index` named tensors." From a9ada81e2c53871d8cd510bdaba49d16b24e5683 Mon Sep 17 00:00:00 2001 From: Asher Spector Date: Wed, 24 Apr 2019 17:25:58 -0400 Subject: [PATCH 3/3] Index_add_ dim orders no longer must be consistent --- namedtensor/test_core.py | 10 +++++----- namedtensor/torch_base.py | 8 ++------ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/namedtensor/test_core.py b/namedtensor/test_core.py index 6104ef5..2fed7c3 100644 --- a/namedtensor/test_core.py +++ b/namedtensor/test_core.py @@ -426,15 +426,15 @@ def test_takes(): def test_index_add(): - base = ntorch.zeros(5, 3, names=("alpha", "beta")) - vals = ntorch.randn(3, 3, names=("time", "beta")) - index = ntorch.randint(0, 5, (3,), names=("time")) + base = ntorch.zeros(3, 5, names=("beta", "alpha")) + vals = ntorch.randn(3, 4, names=("beta", "time")) + index = ntorch.randint(0, 5, (4,), names=("time")) result = base.index_add_("alpha", index, vals) - expected_result = base.values.index_add_(0, index.values, vals.values) + expected_result = base.values.index_add_(1, index.values, vals.values) assert (result._tensor == expected_result).all() - assert result.shape == OrderedDict([("alpha", 5), ("beta", 3)]) + assert result.shape == OrderedDict([("beta", 3), ("alpha", 5)]) def test_narrow(): diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py index 8988021..8f18e1b 100644 --- a/namedtensor/torch_base.py +++ b/namedtensor/torch_base.py @@ -234,12 +234,8 @@ def index_add_(self, dim, index, tensor): by adding to the indices in the order given in 'index'.""" name = dim dim = self._schema.get(name) - tensor_names = [ - n for n in tensor._schema._names if n in index._schema._names - ] - tensor_names += [ - n for n in tensor._schema._names if n not in index._schema._names - ] + tensor_names = [i for i in self._schema._names] + tensor_names[dim] = index._schema._names[0] self._tensor.index_add_( dim, index._tensor, tensor._force_order(tensor_names)._tensor )