
[BUG] TensorClass with non_tensor leads to torch.compile failing (RecursionError or compiling forever) #1336

Closed

@mctigger

Description

Describe the bug

When a @tensorclass object is created inside a compiled function and the tensorclass has fields that are not a torch.Tensor (enforced either by using a type that cannot be cast to a tensor or by setting nocast=True), torch.compile either runs forever or throws an error. The code below illustrates the issue best.

To Reproduce

Note: the cases fail in different ways, as noted in the code. With tensordict-nightly==02025.2.3 I reproduced the behavior below. With tensordict 0.8.3 the behavior differs, but is still wrong (e.g. compiling forever).

import torch
from tensordict import TensorClass
from torch import Tensor


# Same issue with the decorator @tensorclass version
class TensorClassWithNonTensorData(TensorClass["nocast"]):
    tensor: Tensor
    non_tensor_data: int


data = torch.ones(4, 5, device="cuda")


def fn_no_device_no_batch_size():
    a = TensorClassWithNonTensorData(tensor=data, non_tensor_data=1)
    return a.tensor


def fn_no_device():
    a = TensorClassWithNonTensorData(tensor=data, non_tensor_data=1, batch_size=[4])
    return a.tensor


def fn_with_device():
    a = TensorClassWithNonTensorData(
        tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
    )
    return a.tensor


def fn_with_device_without_batch_size():
    a = TensorClassWithNonTensorData(tensor=data, non_tensor_data=1, device="cuda")
    return a.tensor


mode = "with_device_without_batch_size"

match mode:
    case "no_device_no_batch_size":
        # This works: it compiles and prints the result.
        print("no_device_no_batch_size", torch.compile(fn_no_device_no_batch_size)())
    case "no_device":
        # This throws an error. Interestingly, if the previous case is executed
        # first, this doesn't throw an error!
        print("no_device", torch.compile(fn_no_device)())
    case "with_device":
        # This throws an error! Interestingly, if the first case is executed
        # first, this does not throw an error, but compiles forever.
        print("with_device", torch.compile(fn_with_device)())
    case "with_device_without_batch_size":
        # This compiles forever, no matter what.
        print(
            "with_device_without_batch_size",
            torch.compile(fn_with_device_without_batch_size)(),
        )

This is the error when one is thrown (in some cases it compiles forever instead; see above).

torch._dynamo.exc.InternalTorchDynamoError: RecursionError: maximum recursion depth exceeded while calling a Python object

from user code:
   File "/opt/conda/lib/python3.10/site-packages/tensordict/_td.py", line 2453, in _set_str
    value = self._validate_value(
  File "/opt/conda/lib/python3.10/site-packages/tensordict/base.py", line 11108, in _validate_value
    value.batch_size = self.batch_size
  File "/opt/conda/lib/python3.10/site-packages/tensordict/tensorclass.py", line 1530, in wrapper
    return setattr_(self, key, value)
  File "/opt/conda/lib/python3.10/site-packages/tensordict/tensorclass.py", line 1530, in wrapper
    return setattr_(self, key, value)
  File "/opt/conda/lib/python3.10/site-packages/tensordict/tensorclass.py", line 1530, in wrapper
    return setattr_(self, key, value)
  [Previous line repeated 64 more times]
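For context on the traceback: the repeated `wrapper` frames point to a `__setattr__` override that re-enters itself. A minimal stdlib sketch of that failure mode (hypothetical, not tensordict's actual code), alongside the standard way to break the cycle by delegating to `object.__setattr__`:

```python
class Recursing:
    def __setattr__(self, key, value):
        # Assigning through setattr() re-enters this very method,
        # recursing without bound -- analogous to the repeated
        # `wrapper` frames in the traceback above.
        setattr(self, key, value)


class Safe:
    def __setattr__(self, key, value):
        # Delegating to object.__setattr__ breaks the cycle.
        object.__setattr__(self, key, value)


r = Recursing()
try:
    r.x = 1
    recursed = False
except RecursionError:
    recursed = True

s = Safe()
s.x = 1
print(recursed, s.x)  # True 1
```

This only illustrates the general recursion pattern; where exactly tensordict's setattr wrapper re-enters under Dynamo is for the maintainers to pin down.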

Expected behavior

All four cases should compile and return a.tensor.

System info

Describe the characteristic of your environment:

  • Python 3.10.16
  • tensordict-nightly==02025.2.3
  • torch==2.6.0

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug (Something isn't working)