4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import argparse
7
+ import inspect
7
8
8
9
import pytest
9
10
import torch
10
11
import torch .nn as nn
11
12
12
13
from tensordict import TensorDict
13
- from tensordict .nn import TensorDictModule , TensorDictSequential
14
+ from tensordict .nn import TensorDictModule as Mod , TensorDictSequential as Seq
14
15
from tensordict .prototype .fx import symbolic_trace
15
16
16
17
18
+ def test_fx ():
19
+ seq = Seq (
20
+ Mod (lambda x : x + 1 , in_keys = ["x" ], out_keys = ["y" ]),
21
+ Mod (lambda x , y : (x * y ).sqrt (), in_keys = ["x" , "y" ], out_keys = ["z" ]),
22
+ Mod (lambda z , x : z - z , in_keys = ["z" , "x" ], out_keys = ["a" ]),
23
+ )
24
+ symbolic_trace (seq )
25
+
26
+
27
+ class TestModule (torch .nn .Module ):
28
+ def __init__ (self ):
29
+ super ().__init__ ()
30
+ self .linear = torch .nn .Linear (2 , 2 )
31
+
32
+ def forward (self , td : TensorDict ) -> torch .Tensor :
33
+ vals = td .values () # pyre-ignore[6]
34
+ return torch .cat ([val ._values for val in vals ], dim = 0 )
35
+
36
+
37
+ def test_td_scripting () -> None :
38
+ for cls in (TensorDict ,):
39
+ for name in dir (cls ):
40
+ method = inspect .getattr_static (cls , name )
41
+ if isinstance (method , classmethod ):
42
+ continue
43
+ elif isinstance (method , staticmethod ):
44
+ continue
45
+ elif not callable (method ):
46
+ continue
47
+ elif not name .startswith ("__" ) or name in ("__init__" , "__setitem__" ):
48
+ setattr (cls , name , torch .jit .unused (method ))
49
+
50
+ m = TestModule ()
51
+ td = TensorDict (
52
+ a = torch .nested .nested_tensor ([torch .ones ((1 ,))], layout = torch .jagged )
53
+ )
54
+ m (td )
55
+ m = torch .jit .script (m , example_inputs = (td ,))
56
+ m .code
57
+
58
+
17
59
def test_tensordictmodule_trace_consistency ():
18
60
class Net (nn .Module ):
19
61
def __init__ (self ):
@@ -24,7 +66,7 @@ def forward(self, x):
24
66
logits = self .linear (x )
25
67
return logits , torch .sigmoid (logits )
26
68
27
- module = TensorDictModule (
69
+ module = Mod (
28
70
Net (),
29
71
in_keys = ["input" ],
30
72
out_keys = [("outputs" , "logits" ), ("outputs" , "probabilities" )],
@@ -63,15 +105,13 @@ class Masker(nn.Module):
63
105
def forward (self , x , mask ):
64
106
return torch .softmax (x * mask , dim = 1 )
65
107
66
- net = TensorDictModule (
67
- Net (), in_keys = [("input" , "x" )], out_keys = [("intermediate" , "x" )]
68
- )
69
- masker = TensorDictModule (
108
+ net = Mod (Net (), in_keys = [("input" , "x" )], out_keys = [("intermediate" , "x" )])
109
+ masker = Mod (
70
110
Masker (),
71
111
in_keys = [("intermediate" , "x" ), ("input" , "mask" )],
72
112
out_keys = [("output" , "probabilities" )],
73
113
)
74
- module = TensorDictSequential (net , masker )
114
+ module = Seq (net , masker )
75
115
graph_module = symbolic_trace (module )
76
116
77
117
tensordict = TensorDict (
@@ -120,13 +160,11 @@ def forward(self, x):
120
160
module2 = Net (50 , 40 )
121
161
module3 = Output (40 , 10 )
122
162
123
- tdmodule1 = TensorDictModule (module1 , ["input" ], ["x" ])
124
- tdmodule2 = TensorDictModule (module2 , ["x" ], ["x" ])
125
- tdmodule3 = TensorDictModule (module3 , ["x" ], ["probabilities" ])
163
+ tdmodule1 = Mod (module1 , ["input" ], ["x" ])
164
+ tdmodule2 = Mod (module2 , ["x" ], ["x" ])
165
+ tdmodule3 = Mod (module3 , ["x" ], ["probabilities" ])
126
166
127
- tdmodule = TensorDictSequential (
128
- TensorDictSequential (tdmodule1 , tdmodule2 ), tdmodule3
129
- )
167
+ tdmodule = Seq (Seq (tdmodule1 , tdmodule2 ), tdmodule3 )
130
168
graph_module = symbolic_trace (tdmodule )
131
169
132
170
tensordict = TensorDict ({"input" : torch .rand (32 , 100 )}, [32 ])
0 commit comments