diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index b0ca4f3bd4..c0b1a5fe88 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,6 +1,7 @@ from pytensor.graph import node_rewriter from pytensor.tensor import ( broadcast_to, + expand_dims, join, moveaxis, specify_shape, @@ -10,6 +11,7 @@ from pytensor.xtensor.rewriting.basic import register_lower_xtensor from pytensor.xtensor.shape import ( Concat, + ExpandDims, Squeeze, Stack, Transpose, @@ -121,7 +123,7 @@ def lower_transpose(fgraph, node): @register_lower_xtensor @node_rewriter([Squeeze]) -def local_squeeze_reshape(fgraph, node): +def lower_squeeze(fgraph, node): """Rewrite Squeeze to tensor.squeeze.""" [x] = node.inputs x_tensor = tensor_from_xtensor(x) @@ -132,3 +134,33 @@ def local_squeeze_reshape(fgraph, node): new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) return [new_out] + + +@register_lower_xtensor +@node_rewriter([ExpandDims]) +def lower_expand_dims(fgraph, node): + """Rewrite ExpandDims using tensor operations.""" + x, size = node.inputs + out = node.outputs[0] + + # Convert inputs to tensors + x_tensor = tensor_from_xtensor(x) + size_tensor = tensor_from_xtensor(size) + + # Get the new dimension name and position + new_axis = 0 # Always insert at front + + # Use tensor operations + if out.type.shape[0] == 1: + # Simple case: just expand with size 1 + result_tensor = expand_dims(x_tensor, new_axis) + else: + # Otherwise broadcast to the requested size + result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape)) + + # Preserve static shape information + result_tensor = specify_shape(result_tensor, out.type.shape) + + # Convert result back to xtensor + result = xtensor_from_tensor(result_tensor, dims=out.type.dims) + return [result] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index cd0f024e56..f604dc8188 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -1,12 +1,15 @@ import warnings -from collections.abc import Sequence +from collections.abc import Hashable, Sequence from types import EllipsisType from typing import Literal +import numpy as np + from pytensor.graph import Apply from pytensor.scalar import discrete_dtypes, upcast from pytensor.tensor import as_tensor, get_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.type import integer_dtypes from pytensor.xtensor.basic import XOp from pytensor.xtensor.type import as_xtensor, xtensor @@ -380,3 +383,121 @@ def squeeze(x, dim=None): return x # no-op if nothing to squeeze return Squeeze(dims=dims)(x) + + +class ExpandDims(XOp): + """Add a new dimension to an XTensorVariable.""" + + __props__ = ("dim",) + + def __init__(self, dim): + if not isinstance(dim, str): + raise TypeError(f"`dim` must be a string, got: {type(self.dim)}") + + self.dim = dim + + def make_node(self, x, size): + x = as_xtensor(x) + + if self.dim in x.type.dims: + raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}") + + size = as_xtensor(size, dims=()) + if not (size.dtype in integer_dtypes and size.ndim == 0): + raise ValueError(f"size should be an integer scalar, got {size.type}") + try: + static_size = int(get_scalar_constant_value(size)) + except NotScalarConstantError: + static_size = None + + # If size is a constant, validate it + if static_size is not None and static_size < 0: + raise ValueError(f"size must be 0 or positive, got: {static_size}") + new_shape = (static_size, *x.type.shape) + + # Insert new dim at front + new_dims = (self.dim, *x.type.dims) + + out = xtensor( + dtype=x.type.dtype, + shape=new_shape, + dims=new_dims, + ) + return Apply(self, [x, size], [out]) + + +def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs): + """Add one or more new dimensions to an XTensorVariable.""" + x = as_xtensor(x) + + # Store original dimensions for axis handling + original_dims = x.type.dims + + # Warn if create_index_for_new_dim is used (not supported) + if create_index_for_new_dim is not None: + warnings.warn( + "create_index_for_new_dim=False has no effect in pytensor.xtensor", + UserWarning, + stacklevel=2, + ) + + if dim is None: + dim = dim_kwargs + elif dim_kwargs: + raise ValueError("Cannot specify both `dim` and `**dim_kwargs`") + + # Check that dim is Hashable or a sequence of Hashable or dict + if not isinstance(dim, Hashable): + if not isinstance(dim, Sequence | dict): + raise TypeError(f"unhashable type: {type(dim).__name__}") + if not all(isinstance(d, Hashable) for d in dim): + raise TypeError(f"unhashable type in {type(dim).__name__}") + + # Normalize to a dimension-size mapping + if isinstance(dim, str): + dims_dict = {dim: 1} + elif isinstance(dim, Sequence) and not isinstance(dim, dict): + dims_dict = {d: 1 for d in dim} + elif isinstance(dim, dict): + dims_dict = {} + for name, val in dim.items(): + if isinstance(val, str): + raise TypeError(f"Dimension size cannot be a string: {val}") + if isinstance(val, Sequence | np.ndarray): + warnings.warn( + "When a sequence is provided as a dimension size, only its length is used. " + "The actual values (which would be coordinates in xarray) are ignored.", + UserWarning, + stacklevel=2, + ) + dims_dict[name] = len(val) + else: + # should be int or symbolic scalar + dims_dict[name] = val + else: + raise TypeError(f"Invalid type for `dim`: {type(dim)}") + + # Insert each new dim at the front (reverse order preserves user intent) + for name, size in reversed(dims_dict.items()): + x = ExpandDims(dim=name)(x, size) + + # If axis is specified, transpose to put new dimensions in the right place + if axis is not None: + # Wrap non-sequence axis in a list + if not isinstance(axis, Sequence): + axis = [axis] + + # require len(axis) == len(dims_dict) + if len(axis) != len(dims_dict): + raise ValueError("lengths of dim and axis should be identical.") + + # Insert new dimensions at their specified positions + target_dims = list(original_dims) + for name, pos in zip(dims_dict, axis): + # Convert negative axis to positive position relative to current dims + if pos < 0: + pos = len(target_dims) + pos + 1 + target_dims.insert(pos, name) + x = Transpose(dims=tuple(target_dims))(x) + + return x diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index fd601df018..9fea411129 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -481,6 +481,47 @@ def squeeze( raise NotImplementedError("Squeeze with axis not Implemented") return px.shape.squeeze(self, dim) + def expand_dims( + self, + dim: str | Sequence[str] | dict[str, int | Sequence] | None = None, + create_index_for_new_dim: bool = True, + axis: int | Sequence[int] | None = None, + **dim_kwargs, + ): + """Add one or more new dimensions to the tensor. + + Parameters + ---------- + dim : str | Sequence[str] | dict[str, int | Sequence] | None + If str or sequence of str, new dimensions with size 1. + If dict, keys are dimension names and values are either: + - int: the new size + - sequence: coordinates (length determines size) + create_index_for_new_dim : bool, default: True + Currently ignored. Reserved for future coordinate support. + In xarray, when True (default), creates a coordinate index for the new dimension + with values from 0 to size-1. When False, no coordinate index is created. + axis : int | Sequence[int] | None, default: None + Not implemented yet. In xarray, specifies where to insert the new dimension(s). + By default (None), new dimensions are inserted at the beginning (axis=0). + Symbolic axis is not supported yet. + Negative values count from the end. + **dim_kwargs : int | Sequence + Alternative to `dim` dict. Only used if `dim` is None. + + Returns + ------- + XTensorVariable + A tensor with additional dimensions inserted at the front. + """ + return px.shape.expand_dims( + self, + dim, + create_index_for_new_dim=create_index_for_new_dim, + axis=axis, + **dim_kwargs, + ) + # ndarray methods # https://docs.xarray.dev/en/latest/api.html#id7 def clip(self, min, max): diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index f5db72bf1f..69802dcec0 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -8,10 +8,10 @@ from itertools import chain, combinations import numpy as np -import pytest from xarray import DataArray from xarray import concat as xr_concat +from pytensor.tensor import scalar from pytensor.xtensor.shape import ( concat, squeeze, @@ -369,3 +369,113 @@ def test_squeeze_errors(): fn2 = xr_function([x2], y2) with pytest.raises(Exception): fn2(x2_test) + + +def test_expand_dims(): + """Test expand_dims.""" + x = xtensor("x", dims=("city", "year"), shape=(2, 2)) + x_test = xr_arange_like(x) + + # Implicit size 1 + y = x.expand_dims("country") + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims("country")) + + # Test with multiple dimensions + y = x.expand_dims(["country", "state"]) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"])) + + # Test with a dict of name-size pairs + y = x.expand_dims({"country": 2, "state": 3}) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3})) + + # Test with kwargs (equivalent to dict) + y = x.expand_dims(country=2, state=3) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3)) + + # Test with a dict of name-coord array pairs + y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])}) + fn = xr_function([x], y) + xr_assert_allclose( + fn(x_test), + x_test.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])}), + ) + + # Symbolic size 1 + size_sym_1 = scalar("size_sym_1", dtype="int64") + y = x.expand_dims({"country": size_sym_1}) + fn = xr_function([x, size_sym_1], y) + xr_assert_allclose(fn(x_test, 1), x_test.expand_dims({"country": 1})) + + # Test with symbolic sizes in dict + size_sym_2 = scalar("size_sym_2", dtype="int64") + y = x.expand_dims({"country": size_sym_1, "state": size_sym_2}) + fn = xr_function([x, size_sym_1, size_sym_2], y) + xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3})) + + # Test with symbolic sizes in kwargs + y = x.expand_dims(country=size_sym_1, state=size_sym_2) + fn = xr_function([x, size_sym_1, size_sym_2], y) + xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3})) + + # Test with axis parameter + y = x.expand_dims("country", axis=1) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1)) + + # Test with negative axis parameter + y = x.expand_dims("country", axis=-1) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=-1)) + + # Add two new dims with axis parameters + y = x.expand_dims(["country", "state"], axis=[1, 2]) + fn = xr_function([x], y) + xr_assert_allclose( + fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2]) + ) + + # Add two dims with negative axis parameters + y = x.expand_dims(["country", "state"], axis=[-1, -2]) + fn = xr_function([x], y) + xr_assert_allclose( + fn(x_test), x_test.expand_dims(["country", "state"], axis=[-1, -2]) + ) + + # Add two dims with positive and negative axis parameters + y = x.expand_dims(["country", "state"], axis=[-2, 1]) + fn = xr_function([x], y) + xr_assert_allclose( + fn(x_test), x_test.expand_dims(["country", "state"], axis=[-2, 1]) + ) + + +def test_expand_dims_errors(): + """Test error handling in expand_dims.""" + + # Expanding existing dim + x = xtensor("x", dims=("city",), shape=(3,)) + y = x.expand_dims("country") + with pytest.raises(ValueError, match="already exists"): + y.expand_dims("city") + + # Invalid dim type + with pytest.raises(TypeError, match="Invalid type for `dim`"): + x.expand_dims(123) + + # Duplicate dimension creation + y = x.expand_dims("new") + with pytest.raises(ValueError, match="already exists"): + y.expand_dims("new") + + # Find out what xarray does with a numpy array as dim + # x_test = xr_arange_like(x) + # x_test.expand_dims(np.array([1, 2])) + # TypeError: unhashable type: 'numpy.ndarray' + + # Test with a numpy array as dim (not supported) + with pytest.raises(TypeError, match="unhashable type"): + y.expand_dims(np.array([1, 2]))