Skip to content

Added proposed multi_index_select version #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions namedtensor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def __init__(self, tensor, names, mask=0):
"Tensor has %d dim, but only %d names"
% (len(self._tensor.shape), len(self._schema._names))
)
for name in self._schema._names:
assert name.isalnum(), (
"dim name %s must be alphanumeric" % name
)

@property
def dims(self):
Expand Down
110 changes: 104 additions & 6 deletions namedtensor/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ def make_tensors(sizes, names):
return [ntorch.randn(sizes, names=names)]


def test_names():
    """Alphanumeric dim names are accepted when wrapping a raw tensor."""
    raw = torch.zeros([10, 2, 50])
    named = ntorch.tensor(raw, ("alpha", "beta", "gamma"))
    assert named


@pytest.mark.xfail
def test_old_nonzero_names():
    """Old underscored default dim names are now rejected (underscore is not
    alphanumeric), so this construction is expected to fail."""
    raw = torch.zeros([10, 2])
    named = ntorch.tensor(raw, ("elements_dim", "input_dims"))
    assert named


def test_shift():
for ntensor in make_tensors((10, 2, 50), ("alpha", "beta", "gamma")):
# Split
Expand Down Expand Up @@ -287,9 +298,9 @@ def test_nonzero():
# only zeros
x = ntorch.zeros(10, names=("alpha",))
y = x.nonzero()
assert 0 == y.size("elements_dim")
assert 0 == y.size("elementsdim")
assert x.shape == OrderedDict([("alpha", 10)])
assert y.shape == OrderedDict([("elements_dim", 0), ("input_dims", 1)])
assert y.shape == OrderedDict([("elementsdim", 0), ("inputdims", 1)])

# `names` length must be 2
y = x.nonzero(names=("a", "b"))
Expand All @@ -299,9 +310,9 @@ def test_nonzero():
# 1d tensor
x = ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
y = x.nonzero()
assert 3 == y.size("elements_dim")
assert 3 == y.size("elementsdim")
assert x.shape == OrderedDict([("dim", 5)])
assert y.shape == OrderedDict([("elements_dim", 3), ("input_dims", 1)])
assert y.shape == OrderedDict([("elementsdim", 3), ("inputdims", 1)])

# `names` length must be 2
y = x.nonzero(names=("a", "b"))
Expand All @@ -319,9 +330,9 @@ def test_nonzero():
names=("alpha", "beta"),
)
y = x.nonzero()
assert 5 == y.size("elements_dim")
assert 5 == y.size("elementsdim")
assert x.shape == OrderedDict([("alpha", 4), ("beta", 4)])
assert y.shape == OrderedDict([("elements_dim", 5), ("input_dims", 2)])
assert y.shape == OrderedDict([("elementsdim", 5), ("inputdims", 2)])

# `names` length must be 2
y = x.nonzero(names=("a", "b"))
Expand All @@ -343,6 +354,93 @@ def test_nonzero_names():
assert 2 == len(y.shape)


def test_multi_index_select():
    """Exercise `multi_index_select` on 1-d, 3-d, and 4-d tensors, indexing
    varying subsets of dims, including indices produced by `nonzero`."""

    def _check_output(tensor, dims, indices, output):
        # Structural invariants of multi_index_select's result:
        # one output dim per selected element plus every un-indexed input dim.
        names = tensor._schema._names
        index_names = indices._schema._names

        # Output rank: input rank minus indexed dims, plus the elements dim.
        output_names = output._schema._names
        assert len(names) - len(dims) + 1 == len(output_names)

        # The elements dim (first dim of `indices`) carries over unchanged.
        input_elements = indices.shape[index_names[0]]
        output_elements = output.shape[index_names[0]]
        assert input_elements == output_elements

        # Every dim of `tensor` that was not indexed survives in the output.
        remaining_dims = set(names) - set(dims)
        for name in remaining_dims:
            assert name in output_names

        # Surviving dims keep their original sizes.
        output_element_dims = []
        remaining_element_dims = []
        for name in output_names[1:]:
            output_element_dims.append(output.shape[name])
            remaining_element_dims.append(tensor.shape[name])
        assert output_element_dims == remaining_element_dims

    # 1d tensor, nonzero test
    tensor = ntorch.tensor([0.6, 0.4, 0.0], names=('alpha',))
    indices = tensor.nonzero()
    dims = ('alpha',)
    selected_values = tensor.multi_index_select(dims, indices)
    _check_output(tensor, dims, indices, selected_values)

    # 3d tensor
    base = torch.cat([torch.tensor([[[0.6, 0.4, 0.0],
                                     [2.0, 0.0, 1.2]]])] * 4, 0)
    tensor = ntorch.tensor(base, names=('alpha', 'beta', 'gamma'))

    # nonzero test: index every dim at once
    indices = tensor.nonzero()
    dims = ('alpha', 'beta', 'gamma')
    selected_values = tensor.multi_index_select(dims, indices)
    _check_output(tensor, dims, indices, selected_values)

    # one dimension (last dim of the tensor)
    indices = ntorch.tensor(torch.tensor([[0], [1], [1]]),
                            names=('elementsdim', 'inputdims'))
    dims = ('gamma',)
    selected_values = tensor.multi_index_select(dims, indices)
    _check_output(tensor, dims, indices, selected_values)

    # one dimension (first dim of the tensor)
    indices = ntorch.tensor([[1], [2]],
                            names=('elementsdim', 'inputdims'))
    dims = ('alpha',)
    selected_values = tensor.multi_index_select(dims, indices)
    _check_output(tensor, dims, indices, selected_values)

    # two transposed dimensions: `dims` order need not match tensor order
    indices = ntorch.tensor(torch.tensor([[0, 0], [0, 1]]),
                            names=('elementsdim', 'inputdims'))
    dims = ('gamma', 'beta')
    selected_values = tensor.multi_index_select(dims, indices)
    _check_output(tensor, dims, indices, selected_values)

    # 4d tensor
    base = torch.tensor([[0.6, 0.0, 0.0],
                         [0.0, 0.4, 0.0],
                         [0.0, 0.0, 1.2],
                         [2.0, 0.0, 0.9]])
    base = torch.cat([base.unsqueeze(0)] * 5, 0)
    base = torch.cat([base.unsqueeze(0)] * 7, 0)
    tensor = ntorch.tensor(base, names=('dim0', 'dim1', 'dim2', 'dim3'))

    # nonzero test: index every dim at once
    indices = tensor.nonzero()
    dims = ('dim0', 'dim1', 'dim2', 'dim3')
    selected_values = tensor.multi_index_select(dims, indices)
    _check_output(tensor, dims, indices, selected_values)

    # reuse the first two index columns to select a dim prefix
    indices = ntorch.tensor(indices.values[:, :2],
                            names=('elements', 'indims'))
    dims = ('dim0', 'dim1')

    # two dimensions
    selected_values = tensor.multi_index_select(dims, indices)
    _check_output(tensor, dims, indices, selected_values)


# def test_scalar():
# base1 = ntorch.randn(dict(alpha=10, beta=2, gamma=50))
# base2 = base1 + 10
63 changes: 59 additions & 4 deletions namedtensor/torch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def masked_select(input, mask, name):
return NamedTensor(a1.values.masked_select(b1.values), name)

@staticmethod
def nonzero(tensor, names=("elementsdim", "inputdims")):
    """
    Returns a tensor containing the indices of all non-zero elements.

    Parameters
    ----------
    tensor: NamedTensor
    names : tuple, optional
        Names for the output dimensions
        default value: ("elementsdim", "inputdims")
        default output shape:
            OrderedDict([("elementsdim", number of non-zero elements),
                         ("inputdims", input tensor's number of dimensions)])
    """
    # torch.nonzero returns a (num_nonzero, ndim) index matrix; wrap it
    # with the caller-supplied (or default alphanumeric) dim names.
    indices = torch.nonzero(tensor.values)
    return NamedTensor(tensor=indices, names=names)

@staticmethod
def multi_index_select(tensor, dims, indices):
    """
    Select elements of `tensor` at the multi-dimensional positions in
    `indices`.

    Parameters
    ----------
    tensor : NamedTensor
        Source tensor.
    dims : tuple of str
        Names of the dimensions being indexed; column i of `indices`
        addresses dims[i]. Order need not match the tensor's dim order.
    indices : NamedTensor
        2-d integer tensor of shape (num_elements, len(dims)), e.g. the
        output of `nonzero`.

    Returns
    -------
    NamedTensor
        First dim reuses the first dim name of `indices` and has size
        num_elements; the remaining dims are the un-indexed dims of
        `tensor`, in their original order.

    Raises
    ------
    RuntimeError
        If the index width does not match `dims`, `dims` is larger than
        the tensor's rank, `dims` has duplicates, or a name in `dims` is
        not a dimension of `tensor`.
    """
    indices_names = indices._schema._names
    index_dim = indices_names[1]
    if len(dims) != indices.shape[index_dim]:
        raise RuntimeError(
            "Size of elements in 'indices' should be %d, got %d"
            % (len(dims), indices.shape[index_dim])
        )
    if len(tensor.shape) < len(dims):
        raise RuntimeError(
            "Size of 'dims' must be <= tensor dims (%d), got %d"
            % (len(tensor.shape), len(dims))
        )
    if len(set(dims)) < len(dims):
        raise RuntimeError("Tuple 'dims' must contain unique names")
    names = tensor._schema._names
    for dim in dims:
        if dim not in names:
            raise RuntimeError("%s is not a dimension name in tensor" % dim)

    # Positions of the indexed dims, in the order given by `dims`.
    match_dims = [names.index(dim) for dim in dims]

    # Un-indexed dims keep their original relative order.
    remaining_dims = []
    remaining_names = []
    for i, name in enumerate(names):
        if i not in match_dims:
            remaining_dims.append(i)
            remaining_names.append(name)

    # Move the indexed dims to the front so indices line up positionally.
    values = tensor.values.permute(*(match_dims + remaining_dims))

    # Advanced indexing with one index column per matched dim gathers all
    # rows at once; unlike a per-row loop + torch.cat, this also handles
    # an empty `indices` (0 selected elements) cleanly.
    selected = values[tuple(indices.values.t())]

    # Name the leading dim after the elements dim of `indices`, matching
    # the convention used by `nonzero`.
    elements_dim = indices_names[0]
    new_names = tuple([elements_dim] + remaining_names)
    return NamedTensor(tensor=selected, names=new_names)

@staticmethod
def scatter_(input, dim, index, src, index_dim):
indim = dim
Expand Down
15 changes: 11 additions & 4 deletions namedtensor/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,24 @@ def masked_select(self, mask, name):

return ntorch.masked_select(self, mask, name)

def nonzero(self, names=("elements_dim", "input_dims")):
def multi_index_select(self, dims, indices):
    """Select elements of this tensor at the positions given by `indices`.

    `dims` names the dimensions being indexed; `indices` is a 2-d named
    integer tensor of shape (num_elements, len(dims)), e.g. the output of
    `nonzero`. Delegates to `ntorch.multi_index_select`.
    """
    # Local import follows this file's existing lazy-import pattern
    # (see `nonzero` below) — presumably to avoid a circular import.
    from .torch_base import ntorch

    return ntorch.multi_index_select(self, dims, indices)

def nonzero(self, names=("elementsdim", "inputdims")):
"""
Returns a tensor containing the indices of all non-zero elements.

Parameters
----------
names : tuple, optional
Names for the output dimensions
default value: ("elements_dim", "input_dims")
default output shape: OrderedDict([("elements_dim", number of non-zero elements),
("input_dims", input tensor's number of dimensions)])
default value: ("elementsdim", "inputdims")
default output shape:
OrderedDict([("elementsdim", number of non-zero elements),
("inputdims", input tensor's number of dimensions)])
"""

from .torch_base import ntorch
Expand Down