4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import argparse
7
+ import time
8
+ from typing import Any
7
9
8
10
import pytest
9
11
import torch
10
12
from packaging import version
11
13
12
- from tensordict import TensorDict
14
+ from tensordict import tensorclass , TensorDict
15
+ from tensordict .utils import logger as tensordict_logger
13
16
14
17
# Installed torch version with any local/pre-release suffix stripped via
# base_version (e.g. "2.5.1+cu121" -> "2.5.1"), so the skipif comparisons
# below against plain versions behave as expected.
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
15
18
16
19
17
@tensorclass
class NJT:
    """Tensorclass mirror of a torch nested (jagged) tensor.

    Holds the underlying buffers of an NJT — presumably values/offsets/lengths
    of a jagged layout (TODO confirm against torch nested-tensor internals) —
    plus the leading batch size, so the jagged data can be carried around as a
    regular tensorclass leaf inside a TensorDict.
    """

    _values: torch.Tensor
    _offsets: torch.Tensor
    _lengths: torch.Tensor
    # Leading (batch) dimension of the source NJT; populated by from_njt.
    njt_shape: Any = None

    @classmethod
    def from_njt(cls, njt_tensor):
        """Copy the component buffers of ``njt_tensor`` into a new NJT instance.

        ``.clone()`` ensures the tensorclass owns its storage rather than
        aliasing the source tensor's internal buffers.
        """
        return cls(
            _values=njt_tensor._values,
            _offsets=njt_tensor._offsets,
            _lengths=njt_tensor._lengths,
            njt_shape=njt_tensor.size(0),
        ).clone()
35
+
36
+
37
@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
    # Reset torch.compile state before every test so each benchmark run
    # starts from a cold compiler cache rather than reusing artifacts
    # compiled by a previous parametrization.
    torch.compiler.reset()
    yield
27
41
28
42
29
43
def _make_njt ():
@@ -34,14 +48,29 @@ def _make_njt():
34
48
)
35
49
36
50
37
def _njt_td():
    """Build a CPU TensorDict with 32 nested-jagged leaves keyed "0".."31"."""
    entries = {}
    for idx in range(32):
        entries[str(idx)] = _make_njt()
    return TensorDict(entries, device="cpu")
43
57
44
58
59
@pytest.fixture
def njt_td():
    # Fresh TensorDict of nested jagged tensors for each test.
    return _njt_td()
62
+
63
+
64
@pytest.fixture
def td():
    """TensorDict whose leaves are NJT tensorclass copies of the jagged entries."""
    data = _njt_td()
    for key, value in data.items():
        data[key] = NJT.from_njt(value)
    return data
72
+
73
+
45
74
@pytest .fixture
46
75
def default_device ():
47
76
if torch .cuda .is_available ():
@@ -52,22 +81,152 @@ def default_device():
52
81
pytest .skip ("CUDA/MPS is not available" )
53
82
54
83
55
@pytest.mark.parametrize(
    "compile_mode,num_threads",
    [
        [False, None],
        # [False, 4],
        # [False, 16],
        ["default", None],
        ["reduce-overhead", None],
    ],
)
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestConsolidate:
    """Benchmarks for ``TensorDict.consolidate``, eager and under torch.compile."""

    def test_consolidate(
        self, benchmark, td, compile_mode, num_threads, default_device
    ):
        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

        # td = td.to(default_device)

        def consolidate(td, num_threads):
            return td.consolidate(num_threads=num_threads)

        if compile_mode:
            consolidate = torch.compile(
                consolidate, mode=compile_mode, dynamic=False, fullgraph=True
            )

        # Time the first call separately: when compiled it includes
        # compilation overhead, which must not pollute the benchmark below.
        t0 = time.time()
        consolidate(td, num_threads=num_threads)
        elapsed = time.time() - t0
        tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")

        # Warm-up iterations before handing the callable to the benchmark.
        for _ in range(3):
            consolidate(td, num_threads=num_threads)

        benchmark(consolidate, td, num_threads)

    def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads):
        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024:.2f} Mb")

        def consolidate(td, num_threads):
            return td.consolidate(num_threads=num_threads)

        if compile_mode:
            # NOTE(review): compiled NJT consolidation is skipped, not xfailed;
            # revisit once the underlying RuntimeError is fixed upstream.
            pytest.skip(
                "Compiling NJTs consolidation currently triggers a RuntimeError."
            )
            # consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)

        for _ in range(3):
            consolidate(njt_td, num_threads=num_threads)

        benchmark(consolidate, njt_td, num_threads)
139
+
140
+
141
@pytest.mark.parametrize(
    "consolidated,compile_mode,num_threads",
    [
        [False, False, None],
        [True, False, None],
        ["within", False, None],
        # [True, False, 4],
        # [True, False, 16],
        [True, "default", None],
    ],
)
@pytest.mark.skipif(
    # NOTE(review): gate is 2.5.2 while the reason (and TestConsolidate's gate,
    # 2.5.0) says >=2.5 — confirm the intended minimum version.
    TORCH_VERSION < version.parse("2.5.2"),
    reason="requires torch>=2.5",
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found")
class TestTo:
    """Benchmarks for device transfer (``.to``), with/without consolidation.

    ``consolidated`` is three-valued: False (plain ``.to``), True (consolidate
    once up-front, then ``.to``), or "within" (consolidate inside the timed
    callable on every iteration).
    """

    def test_to(
        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
    ):
        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
        # Pinned memory only makes sense for host->CUDA copies.
        pin_mem = default_device.type == "cuda"
        if consolidated is True:
            td = td.consolidate(pin_memory=pin_mem)

        if consolidated == "within":

            def to(td, num_threads):
                return td.consolidate(pin_memory=pin_mem).to(
                    default_device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(default_device, num_threads=num_threads)

        if compile_mode:
            to = torch.compile(to, mode=compile_mode, dynamic=True)

        # Warm-up (also absorbs compilation cost when compile_mode is set).
        for _ in range(3):
            to(td, num_threads=num_threads)

        benchmark(to, td, num_threads)

    def test_to_njt(
        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
    ):
        if compile_mode:
            pytest.skip(
                "Compiling NJTs consolidation currently triggers a RuntimeError."
            )

        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024:.2f} Mb")
        pin_mem = default_device.type == "cuda"
        if consolidated is True:
            njt_td = njt_td.consolidate(pin_memory=pin_mem)

        if consolidated == "within":

            def to(td, num_threads):
                return td.consolidate(pin_memory=pin_mem).to(
                    default_device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(default_device, num_threads=num_threads)

        # (Removed dead code: a second `if compile_mode: torch.compile(...)`
        # branch was unreachable because pytest.skip above raises for any
        # truthy compile_mode.)

        for _ in range(3):
            to(njt_td, num_threads=num_threads)

        benchmark(to, njt_td, num_threads)
69
217
70
218
71
219
if __name__ == "__main__":
    # Forward any CLI flags this script does not recognize straight to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_args = [
        __file__,
        "--capture",
        "no",
        "--exitfirst",
        "--benchmark-group-by",
        "func",
        "-vvv",
    ]
    pytest.main(pytest_args + unknown)