Skip to content

Commit f00d363

Browse files
author
Vincent Moens
committed
[BugFix] TorchScript compat
ghstack-source-id: a2bd5ea Pull Request resolved: #1141
1 parent 2aea3dd commit f00d363

File tree

3 files changed

+55
-17
lines changed

3 files changed

+55
-17
lines changed

tensordict/_td.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1334,7 +1334,7 @@ def _apply_nest(
13341334
filter_empty: bool | None = None,
13351335
is_leaf: Callable | None = None,
13361336
out: TensorDictBase | None = None,
1337-
**constructor_kwargs,
1337+
**constructor_kwargs: Any,
13381338
) -> T | None:
13391339
if inplace:
13401340
result = self

tensordict/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8393,7 +8393,7 @@ def func(*args, **kwargs):
83938393

83948394
def map(
83958395
self,
8396-
fn: Callable[[TensorDictBase], TensorDictBase | None],
8396+
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
83978397
dim: int = 0,
83988398
num_workers: int | None = None,
83998399
*,
@@ -8573,7 +8573,7 @@ def map(
85738573

85748574
def map_iter(
85758575
self,
8576-
fn: Callable[[TensorDictBase], TensorDictBase | None],
8576+
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
85778577
dim: int = 0,
85788578
num_workers: int | None = None,
85798579
*,
@@ -8771,7 +8771,7 @@ def map_iter(
87718771

87728772
def _map(
87738773
self,
8774-
fn: Callable[[TensorDictBase], TensorDictBase | None],
8774+
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
87758775
dim: int = 0,
87768776
*,
87778777
shuffle: bool = False,

test/test_fx.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,58 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import argparse
7+
import inspect
78

89
import pytest
910
import torch
1011
import torch.nn as nn
1112

1213
from tensordict import TensorDict
13-
from tensordict.nn import TensorDictModule, TensorDictSequential
14+
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
1415
from tensordict.prototype.fx import symbolic_trace
1516

1617

18+
def test_fx():
19+
seq = Seq(
20+
Mod(lambda x: x + 1, in_keys=["x"], out_keys=["y"]),
21+
Mod(lambda x, y: (x * y).sqrt(), in_keys=["x", "y"], out_keys=["z"]),
22+
Mod(lambda z, x: z - z, in_keys=["z", "x"], out_keys=["a"]),
23+
)
24+
symbolic_trace(seq)
25+
26+
27+
class TestModule(torch.nn.Module):
28+
def __init__(self):
29+
super().__init__()
30+
self.linear = torch.nn.Linear(2, 2)
31+
32+
def forward(self, td: TensorDict) -> torch.Tensor:
33+
vals = td.values() # pyre-ignore[6]
34+
return torch.cat([val._values for val in vals], dim=0)
35+
36+
37+
def test_td_scripting() -> None:
38+
for cls in (TensorDict,):
39+
for name in dir(cls):
40+
method = inspect.getattr_static(cls, name)
41+
if isinstance(method, classmethod):
42+
continue
43+
elif isinstance(method, staticmethod):
44+
continue
45+
elif not callable(method):
46+
continue
47+
elif not name.startswith("__") or name in ("__init__", "__setitem__"):
48+
setattr(cls, name, torch.jit.unused(method))
49+
50+
m = TestModule()
51+
td = TensorDict(
52+
a=torch.nested.nested_tensor([torch.ones((1,))], layout=torch.jagged)
53+
)
54+
m(td)
55+
m = torch.jit.script(m, example_inputs=(td,))
56+
m.code
57+
58+
1759
def test_tensordictmodule_trace_consistency():
1860
class Net(nn.Module):
1961
def __init__(self):
@@ -24,7 +66,7 @@ def forward(self, x):
2466
logits = self.linear(x)
2567
return logits, torch.sigmoid(logits)
2668

27-
module = TensorDictModule(
69+
module = Mod(
2870
Net(),
2971
in_keys=["input"],
3072
out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
@@ -63,15 +105,13 @@ class Masker(nn.Module):
63105
def forward(self, x, mask):
64106
return torch.softmax(x * mask, dim=1)
65107

66-
net = TensorDictModule(
67-
Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
68-
)
69-
masker = TensorDictModule(
108+
net = Mod(Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")])
109+
masker = Mod(
70110
Masker(),
71111
in_keys=[("intermediate", "x"), ("input", "mask")],
72112
out_keys=[("output", "probabilities")],
73113
)
74-
module = TensorDictSequential(net, masker)
114+
module = Seq(net, masker)
75115
graph_module = symbolic_trace(module)
76116

77117
tensordict = TensorDict(
@@ -120,13 +160,11 @@ def forward(self, x):
120160
module2 = Net(50, 40)
121161
module3 = Output(40, 10)
122162

123-
tdmodule1 = TensorDictModule(module1, ["input"], ["x"])
124-
tdmodule2 = TensorDictModule(module2, ["x"], ["x"])
125-
tdmodule3 = TensorDictModule(module3, ["x"], ["probabilities"])
163+
tdmodule1 = Mod(module1, ["input"], ["x"])
164+
tdmodule2 = Mod(module2, ["x"], ["x"])
165+
tdmodule3 = Mod(module3, ["x"], ["probabilities"])
126166

127-
tdmodule = TensorDictSequential(
128-
TensorDictSequential(tdmodule1, tdmodule2), tdmodule3
129-
)
167+
tdmodule = Seq(Seq(tdmodule1, tdmodule2), tdmodule3)
130168
graph_module = symbolic_trace(tdmodule)
131169

132170
tensordict = TensorDict({"input": torch.rand(32, 100)}, [32])

0 commit comments

Comments
 (0)