Skip to content

Adding expand_dims for xtensor #1449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import (
broadcast_to,
expand_dims,
join,
moveaxis,
specify_shape,
Expand All @@ -10,6 +11,7 @@
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
Squeeze,
Stack,
Transpose,
Expand Down Expand Up @@ -121,7 +123,7 @@ def lower_transpose(fgraph, node):

@register_lower_xtensor
@node_rewriter([Squeeze])
def local_squeeze_reshape(fgraph, node):
def lower_squeeze(fgraph, node):
"""Rewrite Squeeze to tensor.squeeze."""
[x] = node.inputs
x_tensor = tensor_from_xtensor(x)
Expand All @@ -132,3 +134,33 @@ def local_squeeze_reshape(fgraph, node):

new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter([ExpandDims])
def lower_expand_dims(fgraph, node):
    """Lower an ``ExpandDims`` node to plain tensor operations.

    The new dimension is always the leading axis of the output. When the
    output's static leading length is 1 a simple ``expand_dims`` is enough;
    otherwise the input is broadcast to the requested (possibly symbolic)
    size. The output's static shape is re-attached with ``specify_shape``
    before wrapping the result back into an xtensor.
    """
    data, length = node.inputs
    target_type = node.outputs[0].type

    # Drop the dimension labels so regular tensor ops can be applied
    data_t = tensor_from_xtensor(data)
    length_t = tensor_from_xtensor(length)

    if target_type.shape[0] == 1:
        # Statically known size of 1: inserting a leading axis is sufficient
        expanded = expand_dims(data_t, 0)
    else:
        # Any other (or unknown) size: broadcast along a new leading axis
        expanded = broadcast_to(data_t, (length_t, *data_t.shape))

    # Keep whatever static shape information the xtensor output carries
    expanded = specify_shape(expanded, target_type.shape)

    return [xtensor_from_tensor(expanded, dims=target_type.dims)]
123 changes: 122 additions & 1 deletion pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import warnings
from collections.abc import Sequence
from collections.abc import Hashable, Sequence
from types import EllipsisType
from typing import Literal

import numpy as np

from pytensor.graph import Apply
from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor

Expand Down Expand Up @@ -380,3 +383,121 @@ def squeeze(x, dim=None):
return x # no-op if nothing to squeeze

return Squeeze(dims=dims)(x)


class ExpandDims(XOp):
    """Add a new dimension to an XTensorVariable.

    The new dimension is named ``dim`` and is always inserted as the first
    dimension of the output. Its length is given by the ``size`` input of
    the node, which may be a constant or a symbolic integer scalar.
    """

    __props__ = ("dim",)

    def __init__(self, dim):
        """Store the name of the dimension to add.

        Raises
        ------
        TypeError
            If ``dim`` is not a string.
        """
        if not isinstance(dim, str):
            # Format the argument itself: ``self.dim`` is not assigned yet, so
            # reading it here would raise AttributeError instead of TypeError.
            raise TypeError(f"`dim` must be a string, got: {type(dim)}")

        self.dim = dim

    def make_node(self, x, size):
        """Build the Apply node that prepends ``self.dim`` with length ``size``.

        Raises
        ------
        ValueError
            If ``self.dim`` already exists in ``x``, if ``size`` is not an
            integer scalar, or if ``size`` is a negative constant.
        """
        x = as_xtensor(x)

        if self.dim in x.type.dims:
            raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")

        size = as_xtensor(size, dims=())
        if not (size.dtype in integer_dtypes and size.ndim == 0):
            raise ValueError(f"size should be an integer scalar, got {size.type}")
        try:
            static_size = int(get_scalar_constant_value(size))
        except NotScalarConstantError:
            # Symbolic size: the static length of the new dimension is unknown
            static_size = None

        # If size is a constant, validate it
        if static_size is not None and static_size < 0:
            raise ValueError(f"size must be 0 or positive, got: {static_size}")
        new_shape = (static_size, *x.type.shape)

        # Insert new dim at front
        new_dims = (self.dim, *x.type.dims)

        out = xtensor(
            dtype=x.type.dtype,
            shape=new_shape,
            dims=new_dims,
        )
        return Apply(self, [x, size], [out])


def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs):
    """Add one or more new dimensions to an XTensorVariable.

    Parameters
    ----------
    x
        Anything accepted by ``as_xtensor``.
    dim : str, sequence of hashable, or dict, optional
        A single name or a sequence of names adds dimensions of size 1.
        A dict maps each new name to its size: an integer, a symbolic scalar,
        or a sequence / ndarray whose length is used as the size.
    create_index_for_new_dim : bool, optional
        Accepted for signature compatibility only. A ``UserWarning`` is
        emitted when it is explicitly falsy, as the warning text describes.
    axis : int or sequence of int, optional
        Position(s) of the new dimension(s) in the result. Negative values
        count from the end of the result. By default new dimensions are
        placed at the front, in the order given.
    **dim_kwargs
        Alternative to a ``dim`` dict; cannot be combined with ``dim``.

    Raises
    ------
    TypeError
        If ``dim`` (or an entry of it) has an unsupported type, or a size is
        given as a string.
    ValueError
        If both ``dim`` and ``**dim_kwargs`` are given, if ``dim`` or ``axis``
        contain duplicates, or if their lengths differ.
    IndexError
        If an entry of ``axis`` is out of bounds for the expanded result.
    """
    x = as_xtensor(x)

    # Store original dimensions for axis handling
    original_dims = x.type.dims

    # Only an explicit falsy value is worth flagging: that is what the message
    # describes, and wrappers may forward a truthy default on every call.
    if create_index_for_new_dim is not None and not create_index_for_new_dim:
        warnings.warn(
            "create_index_for_new_dim=False has no effect in pytensor.xtensor",
            UserWarning,
            stacklevel=2,
        )

    if dim is None:
        dim = dim_kwargs
    elif dim_kwargs:
        raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")

    # Check that dim is Hashable or a sequence of Hashable or dict
    if not isinstance(dim, Hashable):
        if not isinstance(dim, Sequence | dict):
            raise TypeError(f"unhashable type: {type(dim).__name__}")
        if not all(isinstance(d, Hashable) for d in dim):
            raise TypeError(f"unhashable type in {type(dim).__name__}")

    # Normalize to a dimension-size mapping
    if isinstance(dim, str):
        dims_dict = {dim: 1}
    elif isinstance(dim, dict):
        dims_dict = {}
        for name, val in dim.items():
            if isinstance(val, str):
                raise TypeError(f"Dimension size cannot be a string: {val}")
            if isinstance(val, Sequence | np.ndarray):
                warnings.warn(
                    "When a sequence is provided as a dimension size, only its length is used. "
                    "The actual values (which would be coordinates in xarray) are ignored.",
                    UserWarning,
                    stacklevel=2,
                )
                dims_dict[name] = len(val)
            else:
                # should be int or symbolic scalar
                dims_dict[name] = val
    elif isinstance(dim, Sequence):
        names = list(dim)
        # A dict would silently collapse repeated names; reject them instead
        if len(set(names)) != len(names):
            raise ValueError("dims should not contain duplicate values.")
        dims_dict = dict.fromkeys(names, 1)
    else:
        raise TypeError(f"Invalid type for `dim`: {type(dim)}")

    # Insert each new dim at the front (reverse order preserves user intent)
    for name, size in reversed(dims_dict.items()):
        x = ExpandDims(dim=name)(x, size)

    # If axis is specified, transpose to put new dimensions in the right place
    if axis is not None:
        # Wrap non-sequence axis in a list
        if not isinstance(axis, Sequence):
            axis = [axis]

        # require len(axis) == len(dims_dict)
        if len(axis) != len(dims_dict):
            raise ValueError("lengths of dim and axis should be identical.")

        # Each axis refers to a position in the *final* result, so negative
        # values are resolved against the rank of the expanded tensor.
        result_ndim = len(original_dims) + len(axis)
        for pos in axis:
            if not -result_ndim <= pos < result_ndim:
                raise IndexError(
                    f"axis {pos} is out of bounds for an expanded tensor "
                    f"with {result_ndim} dimensions"
                )
        positions = [pos if pos >= 0 else result_ndim + pos for pos in axis]
        if len(set(positions)) != len(positions):
            raise ValueError("axis should not contain duplicate values")

        # Inserting in ascending position order guarantees every new dim
        # ends up exactly at its requested index in the result.
        target_dims = list(original_dims)
        for pos, name in sorted(zip(positions, dims_dict), key=lambda pair: pair[0]):
            target_dims.insert(pos, name)
        x = Transpose(dims=tuple(target_dims))(x)

    return x
41 changes: 41 additions & 0 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,47 @@ def squeeze(
raise NotImplementedError("Squeeze with axis not Implemented")
return px.shape.squeeze(self, dim)

def expand_dims(
    self,
    dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
    create_index_for_new_dim: bool = True,
    axis: int | Sequence[int] | None = None,
    **dim_kwargs,
):
    """Add one or more new dimensions to the tensor.

    Thin wrapper that forwards every argument to ``px.shape.expand_dims``.

    Parameters
    ----------
    dim : str | Sequence[str] | dict[str, int | Sequence] | None
        If str or sequence of str, new dimensions with size 1.
        If dict, keys are dimension names and values are either:
            - int: the new size
            - sequence: coordinates (length determines size)
    create_index_for_new_dim : bool, default: True
        Has no effect on the result; kept for xarray signature compatibility.
        In xarray, when True (default), a coordinate index is created for the
        new dimension. The value is forwarded as-is, and
        ``px.shape.expand_dims`` may emit a ``UserWarning`` about it.
    axis : int | Sequence[int] | None, default: None
        Position(s) at which to insert the new dimension(s).
        By default (None), new dimensions are inserted at the beginning.
        Negative values count from the end.
        Symbolic axis is not supported yet.
    **dim_kwargs : int | Sequence
        Alternative to `dim` dict. Only used if `dim` is None.

    Returns
    -------
    XTensorVariable
        A tensor with the additional dimensions, placed at the front unless
        `axis` requests other positions.
    """
    return px.shape.expand_dims(
        self,
        dim,
        create_index_for_new_dim=create_index_for_new_dim,
        axis=axis,
        **dim_kwargs,
    )

# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
Expand Down
112 changes: 111 additions & 1 deletion tests/xtensor/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from itertools import chain, combinations

import numpy as np
import pytest
from xarray import DataArray
from xarray import concat as xr_concat

from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
concat,
squeeze,
Expand Down Expand Up @@ -369,3 +369,113 @@ def test_squeeze_errors():
fn2 = xr_function([x2], y2)
with pytest.raises(Exception):
fn2(x2_test)


def test_expand_dims():
    """Compare ``expand_dims`` on xtensors against xarray for each calling style."""
    x = xtensor("x", dims=("city", "year"), shape=(2, 2))
    x_test = xr_arange_like(x)

    def check_static(*args, **kwargs):
        # The same call is applied to the symbolic variable and to the
        # xarray reference, then the evaluated result is compared.
        fn = xr_function([x], x.expand_dims(*args, **kwargs))
        xr_assert_allclose(fn(x_test), x_test.expand_dims(*args, **kwargs))

    # Implicit size 1, several names, dict of sizes, kwargs, dict of coords
    check_static("country")
    check_static(["country", "state"])
    check_static({"country": 2, "state": 3})
    check_static(country=2, state=3)
    check_static({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})

    # Symbolic sizes: the size variables become extra function inputs
    size_sym_1 = scalar("size_sym_1", dtype="int64")
    size_sym_2 = scalar("size_sym_2", dtype="int64")

    one_sym = x.expand_dims({"country": size_sym_1})
    fn = xr_function([x, size_sym_1], one_sym)
    xr_assert_allclose(fn(x_test, 1), x_test.expand_dims({"country": 1}))

    expected_two = {"country": 2, "state": 3}
    two_sym_variants = (
        # symbolic sizes passed in a dict
        x.expand_dims({"country": size_sym_1, "state": size_sym_2}),
        # symbolic sizes passed as keyword arguments
        x.expand_dims(country=size_sym_1, state=size_sym_2),
    )
    for two_sym in two_sym_variants:
        fn = xr_function([x, size_sym_1, size_sym_2], two_sym)
        xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims(expected_two))

    # Explicit placement through ``axis``: positive, negative and mixed
    check_static("country", axis=1)
    check_static("country", axis=-1)
    for axes in ([1, 2], [-1, -2], [-2, 1]):
        check_static(["country", "state"], axis=axes)


def test_expand_dims_errors():
    """Check the exceptions raised by ``expand_dims`` for invalid requests."""
    x = xtensor("x", dims=("city",), shape=(3,))

    # Re-adding a dimension that the variable already carries
    with_country = x.expand_dims("country")
    with pytest.raises(ValueError, match="already exists"):
        with_country.expand_dims("city")

    # A hashable value that is neither a name, a sequence nor a dict
    with pytest.raises(TypeError, match="Invalid type for `dim`"):
        x.expand_dims(123)

    # Adding the same new dimension twice
    with_new = x.expand_dims("new")
    with pytest.raises(ValueError, match="already exists"):
        with_new.expand_dims("new")

    # A numpy array as `dim` is not supported. The original author noted
    # that xarray fails the same way for `x_test.expand_dims(np.array([1, 2]))`:
    # TypeError: unhashable type: 'numpy.ndarray'
    with pytest.raises(TypeError, match="unhashable type"):
        with_new.expand_dims(np.array([1, 2]))