Commit 5270c64

[BugFix] Fix parsing integer batch size in AOT

Author: Vincent Moens
ghstack-source-id: 18a5798
Pull Request resolved: #1004
Parent: 85b6b81

2 files changed: 48 additions, 9 deletions

tensordict/_td.py (5 additions, 5 deletions)

@@ -2061,7 +2061,7 @@ def _parse_batch_size(
         source: T | dict | None,
         batch_size: Sequence[int] | torch.Size | int | None = None,
     ) -> torch.Size:
-        ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."
+        ERR = "batch size {} was not specified when creating the TensorDict instance and it could not be retrieved from source."

         if is_dynamo_compiling():
             if isinstance(batch_size, torch.Size):
@@ -2072,22 +2072,22 @@ def _parse_batch_size(
                 return torch.Size(tuple(batch_size))
             if batch_size is None:
                 return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
                 return torch.Size([batch_size])
             elif isinstance(source, TensorDictBase):
                 return source.batch_size
-            raise ValueError()
+            raise ValueError(ERR.format(batch_size))

         try:
             return torch.Size(batch_size)
         except Exception:
             if batch_size is None:
                 return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
                 return torch.Size([batch_size])
             elif isinstance(source, TensorDictBase):
                 return source.batch_size
-            raise ValueError(ERR)
+            raise ValueError(ERR.format(batch_size))

     @property
     def batch_dims(self) -> int:
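
For context on the two changes above: when a module traced with torch.export builds a TensorDict from a tensor's leading dimension, the batch size arrives as a torch.SymInt rather than a plain Python number, so the old isinstance(batch_size, Number) check fell through to a bare ValueError(). The second change interpolates the offending value into the error text. A minimal sketch of the improved message, assuming a tensordict build that includes this patch; the object() argument is just an arbitrary invalid batch size:

    import torch
    from tensordict import TensorDict

    try:
        # Neither a Number, a torch.SymInt, nor a valid Sequence, so
        # _parse_batch_size falls through to the ValueError branch.
        TensorDict({"a": torch.zeros(3)}, batch_size=object())
    except ValueError as err:
        print(err)  # the message now names the rejected batch size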

test/test_compile.py (43 additions, 4 deletions)

@@ -774,26 +774,65 @@ def call(x, td):


 @pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
+@pytest.mark.parametrize("strict", [True, False])
 class TestExport:
-    def test_export_module(self):
+    def test_export_module(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()

-    def test_export_seq(self):
+    def test_export_seq(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Seq(
             Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
             Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
         )
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))

+    @pytest.mark.parametrize(
+        "same_shape,dynamic_shape", [[True, True], [True, False], [False, True]]
+    )
+    def test_td_output(self, strict, same_shape, dynamic_shape):
+        # This will only work when the tensordict is pytree-able
+        class Test(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return TensorDict(
+                    {
+                        "x": x,
+                        "y": y,
+                    },
+                    batch_size=x.shape[0],
+                )
+
+        test = Test()
+        if same_shape:
+            x, y = torch.zeros(5, 100), torch.zeros(5, 100)
+        else:
+            x, y = torch.zeros(2, 100), torch.zeros(2, 100)
+        if dynamic_shape:
+            kwargs = {
+                "dynamic_shapes": {
+                    "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                    "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                }
+            }
+        else:
+            kwargs = {}
+
+        result = torch.export.export(test, args=(x, y), strict=False, **kwargs)
+        export_mod = result.module()
+        x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
+        export_test = export_mod(x_new, y_new)
+        eager_test = test(x_new, y_new)
+        assert eager_test.batch_size == export_test.batch_size
+        assert (export_test == eager_test).all()
+

 @pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
 class TestONNXExport:
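
A standalone version of what test_td_output exercises may make the fix concrete: exporting a module whose forward returns a TensorDict keyed off x.shape[0], which is a torch.SymInt once that dimension is marked dynamic. A rough sketch, assuming PT>=2.5 and a tensordict build that includes this patch; the module name MakeTD and the dim name "batch" are illustrative:

    import torch
    from tensordict import TensorDict

    class MakeTD(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            # Under export with a dynamic dim 0, x.shape[0] is a torch.SymInt;
            # before this fix, _parse_batch_size rejected it.
            return TensorDict({"x": x}, batch_size=x.shape[0])

    ep = torch.export.export(
        MakeTD(),
        args=(torch.zeros(4, 3),),
        dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
        strict=False,
    )
    # The exported module accepts other sizes along the dynamic dimension.
    print(ep.module()(torch.zeros(8, 3)).batch_size)  # torch.Size([8])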
