Describe the bug
When a @tensorclass object is created inside a compiled function and the tensorclass has fields that are not a torch.Tensor (either because the field type cannot be cast to a tensor, or because nocast=True is set), torch.compile either runs forever or throws an error. The code below illustrates the issue best.
To Reproduce
Note: The cases may fail in different ways, as noted in the code comments. With tensordict-nightly==02025.2.3 I reproduced the behavior below. With tensordict 0.8.3 the behavior differs, but is still wrong (e.g. compiling forever).
```python
import torch
from tensordict import TensorClass
from torch import Tensor


# Same issue with the decorator @tensorclass version
class TensorClassWithNonTensorData(TensorClass["nocast"]):
    tensor: Tensor
    non_tensor_data: int


data = torch.ones(4, 5, device="cuda")


def fn_no_device_no_batch_size():
    a = TensorClassWithNonTensorData(tensor=data, non_tensor_data=1)
    return a.tensor


def fn_no_device():
    a = TensorClassWithNonTensorData(tensor=data, non_tensor_data=1, batch_size=[4])
    return a.tensor


def fn_with_device():
    a = TensorClassWithNonTensorData(
        tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
    )
    return a.tensor


def fn_with_device_without_batch_size():
    a = TensorClassWithNonTensorData(tensor=data, non_tensor_data=1, device="cuda")
    return a.tensor


mode = "with_device_without_batch_size"

match mode:
    case "no_device_no_batch_size":
        # This prints as expected.
        print("no_device_no_batch_size", torch.compile(fn_no_device_no_batch_size)())
    case "no_device":
        # This throws an error. Interestingly, if the previous case is executed
        # first, no error is thrown!
        print("no_device", torch.compile(fn_no_device)())
    case "with_device":
        # This throws an error! Interestingly, if the first case is executed
        # first, no error is thrown, but it compiles forever.
        print("with_device", torch.compile(fn_with_device)())
    case "with_device_without_batch_size":
        # This compiles forever, no matter what.
        print(
            "with_device_without_batch_size",
            torch.compile(fn_with_device_without_batch_size)(),
        )
```
This is the error when one is thrown (in some cases it compiles forever instead; see above).
```
torch._dynamo.exc.InternalTorchDynamoError: RecursionError: maximum recursion depth exceeded while calling a Python object

from user code:
  File "/opt/conda/lib/python3.10/site-packages/tensordict/_td.py", line 2453, in _set_str
    value = self._validate_value(
  File "/opt/conda/lib/python3.10/site-packages/tensordict/base.py", line 11108, in _validate_value
    value.batch_size = self.batch_size
  File "/opt/conda/lib/python3.10/site-packages/tensordict/tensorclass.py", line 1530, in wrapper
    return setattr_(self, key, value)
  File "/opt/conda/lib/python3.10/site-packages/tensordict/tensorclass.py", line 1530, in wrapper
    return setattr_(self, key, value)
  File "/opt/conda/lib/python3.10/site-packages/tensordict/tensorclass.py", line 1530, in wrapper
    return setattr_(self, key, value)
  [Previous line repeated 64 more times]
```
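The tail of the traceback shows the tensorclass `__setattr__` wrapper calling itself until the recursion limit is hit. As a minimal, self-contained sketch of that general failure mode (plain Python, not tensordict's actual code), a `__setattr__` that re-enters attribute assignment recurses without bound, while delegating to `object.__setattr__` breaks the loop:

```python
class Recursing:
    def __setattr__(self, key, value):
        # Delegating back to setattr() dispatches to this same
        # __setattr__ again: infinite recursion, like the
        # wrapper -> setattr_ -> wrapper loop in the traceback.
        setattr(self, key, value)


class Guarded:
    def __setattr__(self, key, value):
        # Assigning through the base class bypasses the custom
        # __setattr__ and terminates the chain.
        object.__setattr__(self, key, value)


try:
    Recursing().x = 1
except RecursionError:
    print("Recursing: RecursionError")

g = Guarded()
g.x = 1
print("Guarded:", g.x)
```

This only illustrates the pattern visible in the traceback; the actual loop here presumably involves `_validate_value` re-entering the wrapped `__setattr__` via `value.batch_size = self.batch_size`.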
Expected behavior
It should compile.
System info
Describe the characteristics of your environment:
- Python 3.10.16
- tensordict-nightly==02025.2.3
- torch==2.6.0
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)