4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import argparse
7
+ import time
8
+ from typing import Any
7
9
8
10
import pytest
9
11
import torch
10
12
from packaging import version
11
13
12
- from tensordict import TensorDict
14
+ from tensordict import tensorclass , TensorDict
15
+ from tensordict .utils import logger as tensordict_logger
13
16
14
17
TORCH_VERSION = version .parse (version .parse (torch .__version__ ).base_version )
15
18
16
19
17
- @pytest .fixture
18
- def td ():
19
- return TensorDict (
20
- {
21
- str (i ): {str (j ): torch .randn (16 , 16 , device = "cpu" ) for j in range (16 )}
22
- for i in range (16 )
23
- },
24
- batch_size = [16 ],
25
- device = "cpu" ,
26
- )
20
@tensorclass
class NJT:
    """Tensorclass mirror of a torch nested jagged tensor (NJT).

    Holds the raw NJT components (`_values`, `_offsets`, `_lengths`) as plain
    tensors so the data survives TensorDict operations that do not understand
    nested tensors. ``njt_shape`` records the leading (jagged) dimension size.
    """

    _values: torch.Tensor
    _offsets: torch.Tensor
    _lengths: torch.Tensor
    njt_shape: Any = None

    @classmethod
    def from_njt(cls, njt_tensor):
        """Build an :class:`NJT` from a nested jagged tensor, cloning storage."""
        components = {
            "_values": njt_tensor._values,
            "_offsets": njt_tensor._offsets,
            "_lengths": njt_tensor._lengths,
            "njt_shape": njt_tensor.size(0),
        }
        return cls(**components).clone()
35
+
36
+
37
@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
    """Reset torch.compile caches before every test so compilation state from
    one benchmark cannot pollute the timings of the next."""
    torch.compiler.reset()
    yield None
27
41
28
42
29
43
def _make_njt ():
@@ -34,14 +48,29 @@ def _make_njt():
34
48
)
35
49
36
50
37
- @pytest .fixture
38
- def njt_td ():
51
def _njt_td():
    """Return a CPU TensorDict with 32 nested-jagged-tensor leaves."""
    entries = {str(idx): _make_njt() for idx in range(32)}
    return TensorDict(entries, device="cpu")
43
57
44
58
59
@pytest.fixture
def njt_td():
    """Fresh TensorDict of nested jagged tensors for each test."""
    return _njt_td()
62
+
63
+
64
@pytest.fixture
def td():
    """TensorDict where every nested-jagged leaf of :func:`_njt_td` has been
    replaced by its plain-tensor :class:`NJT` tensorclass representation."""
    njtd = _njt_td()
    for key in list(njtd.keys()):
        njtd[key] = NJT.from_njt(njtd[key])
    return njtd
72
+
73
+
45
74
@pytest .fixture
46
75
def default_device ():
47
76
if torch .cuda .is_available ():
@@ -52,22 +81,147 @@ def default_device():
52
81
pytest .skip ("CUDA/MPS is not available" )
53
82
54
83
55
- @pytest .mark .parametrize ("consolidated" , [False , True ])
84
@pytest.mark.parametrize(
    "compile_mode,num_threads",
    [
        [False, None],
        # [False, 4],
        # [False, 16],
        ["default", None],
        ["reduce-overhead", None],
    ],
)
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestConsolidate:
    """Benchmarks for ``TensorDict.consolidate``, eager and under torch.compile."""

    def test_consolidate(
        self, benchmark, td, compile_mode, num_threads, default_device
    ):
        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

        # td = td.to(default_device)

        def run(td, num_threads):
            return td.consolidate(num_threads=num_threads)

        if compile_mode:
            run = torch.compile(
                run, mode=compile_mode, dynamic=False, fullgraph=True
            )

        # Time the first call separately: it includes compilation overhead.
        start = time.time()
        run(td, num_threads=num_threads)
        elapsed = time.time() - start
        tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")

        # Warm up before handing control to the benchmark fixture.
        for _ in range(3):
            run(td, num_threads=num_threads)

        benchmark(run, td, num_threads)

    def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads):
        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024:.2f} Mb")

        # NOTE(review): compiled NJT consolidation is currently broken upstream.
        if compile_mode:
            pytest.skip(
                "Compiling NJTs consolidation currently triggers a RuntimeError."
            )

        def run(td, num_threads):
            return td.consolidate(num_threads=num_threads)

        for _ in range(3):
            run(njt_td, num_threads=num_threads)

        benchmark(run, njt_td, num_threads)
139
+
140
+
141
@pytest.mark.parametrize(
    "consolidated,compile_mode,num_threads",
    [
        [False, False, None],
        [True, False, None],
        ["within", False, None],
        # [True, False, 4],
        # [True, False, 16],
        [True, "default", None],
    ],
)
@pytest.mark.skipif(
    # Fix: reason message now matches the actual version gate (was ">=2.5").
    TORCH_VERSION < version.parse("2.5.2"), reason="requires torch>=2.5.2"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found")
class TestTo:
    """Benchmarks for moving a TensorDict to a device.

    ``consolidated`` selects the strategy: ``False`` (plain ``.to``), ``True``
    (consolidate once up front, outside the timed call), or ``"within"``
    (consolidation cost measured inside every timed call).
    """

    def _run_to_benchmark(
        self, benchmark, data, consolidated, compile_mode, num_threads, default_device
    ):
        """Shared driver for test_to / test_to_njt (logic was duplicated)."""
        # Pinned host memory only pays off for CPU -> CUDA transfers.
        pin_mem = default_device.type == "cuda"
        if consolidated is True:
            # Consolidate once, outside the benchmarked callable.
            data = data.consolidate(pin_memory=pin_mem)

        if consolidated == "within":

            def to(td, num_threads):
                # Consolidation happens inside the timed call.
                return td.consolidate(pin_memory=pin_mem).to(
                    default_device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(default_device, num_threads=num_threads)

        if compile_mode:
            to = torch.compile(to, mode=compile_mode, dynamic=True)

        # Warm up (compilation, allocator, caches) before measuring.
        for _ in range(3):
            to(data, num_threads=num_threads)

        benchmark(to, data, num_threads)

    def test_to(
        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
    ):
        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
        self._run_to_benchmark(
            benchmark, td, consolidated, compile_mode, num_threads, default_device
        )

    def test_to_njt(
        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
    ):
        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024:.2f} Mb")
        self._run_to_benchmark(
            benchmark, njt_td, consolidated, compile_mode, num_threads, default_device
        )
69
212
70
213
71
214
if __name__ == "__main__":
    # Forward any unrecognized CLI flags straight to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_args = [
        __file__,
        "--capture",
        "no",
        "--exitfirst",
        "--benchmark-group-by",
        "func",
        "-vvv",
    ]
    pytest.main(pytest_args + unknown)
0 commit comments