Commit ff94c46

[BugFix] Fix parsing integer batch size in AOT

Author: Vincent Moens
ghstack-source-id: ffd60b7
Pull Request resolved: #1004
Parent: 1659518
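
The bug in one line: under torch.export / AOT tracing with a dynamic batch dimension, x.shape[0] is a torch.SymInt rather than a Python int, so the previous isinstance(batch_size, Number) check rejected it and _parse_batch_size raised a bare ValueError(). A minimal sketch of the pattern this commit makes work (module name and shapes are illustrative; it assumes a tensordict version with this fix, where TensorDict is pytree-registered):

import torch
from tensordict import TensorDict

class ReturnTD(torch.nn.Module):
    def forward(self, x):
        # During export tracing with a dynamic dim 0, x.shape[0] is a
        # torch.SymInt; before this fix _parse_batch_size rejected it.
        return TensorDict({"x": x}, batch_size=x.shape[0])

batch = torch.export.Dim("batch")
ep = torch.export.export(
    ReturnTD(),
    args=(torch.zeros(5, 3),),
    dynamic_shapes={"x": {0: batch}},
    strict=False,
)
print(ep.module()(torch.zeros(5, 3)).batch_size)  # torch.Size([5])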

2 files changed: +48 -9 lines

tensordict/_td.py (5 additions & 5 deletions)

@@ -2054,7 +2054,7 @@ def _parse_batch_size(
         source: T | dict | None,
         batch_size: Sequence[int] | torch.Size | int | None = None,
     ) -> torch.Size:
-        ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."
+        ERR = "batch size {} was not specified when creating the TensorDict instance and it could not be retrieved from source."
 
         if is_dynamo_compiling():
             if isinstance(batch_size, torch.Size):
@@ -2065,22 +2065,22 @@ def _parse_batch_size(
                 return torch.Size(tuple(batch_size))
             if batch_size is None:
                 return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
                 return torch.Size([batch_size])
             elif isinstance(source, TensorDictBase):
                 return source.batch_size
-            raise ValueError()
+            raise ValueError(ERR.format(batch_size))
 
         try:
             return torch.Size(batch_size)
         except Exception:
             if batch_size is None:
                 return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
                 return torch.Size([batch_size])
             elif isinstance(source, TensorDictBase):
                 return source.batch_size
-            raise ValueError(ERR)
+            raise ValueError(ERR.format(batch_size))
 
     @property
     def batch_dims(self) -> int:
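
As the diff shows, batch_size may now be a Number or a torch.SymInt on both the compiled and the eager path, and the error message now carries the offending value via ERR.format(batch_size). A quick sketch of the accepted forms (eager-mode examples; the SymInt case only arises inside dynamo/export tracing):

import torch
from tensordict import TensorDict

t = torch.zeros(3, 4)
# All of these normalize to batch_size == torch.Size([3]):
TensorDict({"a": t}, batch_size=3)                # a plain integer
TensorDict({"a": t}, batch_size=[3])              # any integer sequence
TensorDict({"a": t}, batch_size=torch.Size([3]))  # a torch.Size
# Inside tracing, a torch.SymInt such as t.shape[0] is treated like an int.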

test/test_compile.py (43 additions & 4 deletions)

@@ -801,26 +801,65 @@ def call(x, td):
 
 
 @pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
+@pytest.mark.parametrize("strict", [True, False])
 class TestExport:
-    def test_export_module(self):
+    def test_export_module(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()
 
-    def test_export_seq(self):
+    def test_export_seq(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Seq(
             Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
             Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
         )
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
 
+    @pytest.mark.parametrize(
+        "same_shape,dynamic_shape", [[True, True], [True, False], [False, True]]
+    )
+    def test_td_output(self, strict, same_shape, dynamic_shape):
+        # This will only work when the tensordict is pytree-able
+        class Test(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return TensorDict(
+                    {
+                        "x": x,
+                        "y": y,
+                    },
+                    batch_size=x.shape[0],
+                )
+
+        test = Test()
+        if same_shape:
+            x, y = torch.zeros(5, 100), torch.zeros(5, 100)
+        else:
+            x, y = torch.zeros(2, 100), torch.zeros(2, 100)
+        if dynamic_shape:
+            kwargs = {
+                "dynamic_shapes": {
+                    "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                    "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                }
+            }
+        else:
+            kwargs = {}
+
+        result = torch.export.export(test, args=(x, y), strict=False, **kwargs)
+        export_mod = result.module()
+        x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
+        export_test = export_mod(x_new, y_new)
+        eager_test = test(x_new, y_new)
+        assert eager_test.batch_size == export_test.batch_size
+        assert (export_test == eager_test).all()
+
 
 @pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
 class TestONNXExport:
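
For reference, the dynamic_shapes idiom used in test_td_output: giving two inputs the same torch.export.Dim tells export that those dimensions must be equal, and it lets the exported module accept batch sizes other than those of the example inputs. A condensed sketch, reusing the Test module from the diff above:

import torch

test = Test()
batch, time = torch.export.Dim("batch"), torch.export.Dim("time")
ep = torch.export.export(
    test,
    args=(torch.zeros(2, 100), torch.zeros(2, 100)),
    strict=False,
    dynamic_shapes={"x": {0: batch, 1: time}, "y": {0: batch, 1: time}},
)
# The batch size recorded in the output TensorDict follows the runtime input:
print(ep.module()(torch.zeros(5, 100), torch.zeros(5, 100)).batch_size)  # torch.Size([5])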
