4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import argparse
7
+ import time
8
+ from typing import Any
7
9
8
10
import pytest
9
11
import torch
10
12
from packaging import version
11
13
12
- from tensordict import TensorDict
14
+ from tensordict import tensorclass , TensorDict
15
+ from tensordict .utils import logger as tensordict_logger
13
16
14
17
TORCH_VERSION = version .parse (version .parse (torch .__version__ ).base_version )
15
18
16
19
17
- @pytest .fixture
18
- def td ():
19
- return TensorDict (
20
- {
21
- str (i ): {str (j ): torch .randn (16 , 16 , device = "cpu" ) for j in range (16 )}
22
- for i in range (16 )
23
- },
24
- batch_size = [16 ],
25
- device = "cpu" ,
26
- )
20
+ @tensorclass
21
+ class NJT :
22
+ _values : torch .Tensor
23
+ _offsets : torch .Tensor
24
+ _lengths : torch .Tensor
25
+ njt_shape : Any = None
26
+
27
+ @classmethod
28
+ def from_njt (cls , njt_tensor ):
29
+ return cls (
30
+ _values = njt_tensor ._values ,
31
+ _offsets = njt_tensor ._offsets ,
32
+ _lengths = njt_tensor ._lengths ,
33
+ njt_shape = njt_tensor .size (0 ),
34
+ ).clone ()
35
+
36
+
37
+ @pytest .fixture (autouse = True , scope = "function" )
38
+ def empty_compiler_cache ():
39
+ torch .compiler .reset ()
40
+ yield
27
41
28
42
29
43
def _make_njt ():
@@ -34,14 +48,29 @@ def _make_njt():
34
48
)
35
49
36
50
37
- @pytest .fixture
38
- def njt_td ():
51
+ def _njt_td ():
39
52
return TensorDict (
40
- {str (i ): {str (j ): _make_njt () for j in range (32 )} for i in range (32 )},
53
+ # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
54
+ {str (i ): _make_njt () for i in range (128 )},
41
55
device = "cpu" ,
42
56
)
43
57
44
58
59
+ @pytest .fixture
60
+ def njt_td ():
61
+ return _njt_td ()
62
+
63
+
64
+ @pytest .fixture
65
+ def td ():
66
+ njtd = _njt_td ()
67
+ for k0 , v0 in njtd .items ():
68
+ njtd [k0 ] = NJT .from_njt (v0 )
69
+ # for k1, v1 in v0.items():
70
+ # njtd[k0, k1] = NJT.from_njt(v1)
71
+ return njtd
72
+
73
+
45
74
@pytest .fixture
46
75
def default_device ():
47
76
if torch .cuda .is_available ():
@@ -52,22 +81,139 @@ def default_device():
52
81
pytest .skip ("CUDA/MPS is not available" )
53
82
54
83
55
- @pytest .mark .parametrize ("consolidated" , [False , True ])
84
+ @pytest .mark .parametrize (
85
+ "compile_mode,num_threads" ,
86
+ [
87
+ [False , None ],
88
+ # [False, 4],
89
+ # [False, 16],
90
+ ["default" , None ],
91
+ ["reduce-overhead" , None ],
92
+ ],
93
+ )
94
+ @pytest .mark .skipif (
95
+ TORCH_VERSION < version .parse ("2.5.0" ), reason = "requires torch>=2.5"
96
+ )
97
+ class TestConsolidate :
98
+ def test_consolidate (self , benchmark , td , compile_mode , num_threads ):
99
+ tensordict_logger .info (f"td size { td .bytes () / 1024 / 1024 :.2f} Mb" )
100
+
101
+ def consolidate (td , num_threads ):
102
+ return td .consolidate (num_threads = num_threads )
103
+
104
+ if compile_mode :
105
+ consolidate = torch .compile (
106
+ consolidate , mode = compile_mode , dynamic = True , fullgraph = True
107
+ )
108
+
109
+ t0 = time .time ()
110
+ consolidate (td , num_threads = num_threads )
111
+ elapsed = time .time () - t0
112
+ tensordict_logger .info (f"elapsed time first call: { elapsed :.2f} sec" )
113
+
114
+ for _ in range (3 ):
115
+ consolidate (td , num_threads = num_threads )
116
+
117
+ benchmark (consolidate , td , num_threads )
118
+
119
+ def test_to_njt (self , benchmark , njt_td , compile_mode , num_threads ):
120
+ tensordict_logger .info (f"njtd size { njt_td .bytes () / 1024 / 1024 :.2f} Mb" )
121
+
122
+ def consolidate (td , num_threads ):
123
+ return td .consolidate (num_threads = num_threads )
124
+
125
+ if compile_mode :
126
+ consolidate = torch .compile (consolidate , mode = compile_mode , dynamic = True )
127
+
128
+ for _ in range (3 ):
129
+ consolidate (njt_td , num_threads = num_threads )
130
+
131
+ benchmark (consolidate , njt_td , num_threads )
132
+
133
+
134
+ @pytest .mark .parametrize (
135
+ "consolidated,compile_mode,num_threads" ,
136
+ [
137
+ [False , False , None ],
138
+ [True , False , None ],
139
+ ["within" , False , None ],
140
+ # [True, False, 4],
141
+ # [True, False, 16],
142
+ [True , "default" , None ],
143
+ ],
144
+ )
56
145
@pytest .mark .skipif (
57
146
TORCH_VERSION < version .parse ("2.5.1" ), reason = "requires torch>=2.5"
58
147
)
59
148
class TestTo :
60
- def test_to (self , benchmark , consolidated , td , default_device ):
61
- if consolidated :
62
- td = td .consolidate ()
63
- benchmark (lambda : td .to (default_device ))
149
+ def test_to (
150
+ self , benchmark , consolidated , td , default_device , compile_mode , num_threads
151
+ ):
152
+ tensordict_logger .info (f"td size { td .bytes () / 1024 / 1024 :.2f} Mb" )
153
+ pin_mem = default_device .type == "cuda"
154
+ if consolidated is True :
155
+ td = td .consolidate (pin_memory = pin_mem )
156
+
157
+ if consolidated == "within" :
158
+
159
+ def to (td , num_threads ):
160
+ return td .consolidate (pin_memory = pin_mem ).to (
161
+ default_device , num_threads = num_threads
162
+ )
163
+
164
+ else :
64
165
65
- def test_to_njt (self , benchmark , consolidated , njt_td , default_device ):
66
- if consolidated :
67
- njt_td = njt_td .consolidate ()
68
- benchmark (lambda : njt_td .to (default_device ))
166
+ def to (td , num_threads ):
167
+ return td .to (default_device , num_threads = num_threads )
168
+
169
+ if compile_mode :
170
+ to = torch .compile (to , mode = compile_mode , dynamic = True )
171
+
172
+ for _ in range (3 ):
173
+ to (td , num_threads = num_threads )
174
+
175
+ benchmark (to , td , num_threads )
176
+
177
+ def test_to_njt (
178
+ self , benchmark , consolidated , njt_td , default_device , compile_mode , num_threads
179
+ ):
180
+ tensordict_logger .info (f"njtd size { njt_td .bytes () / 1024 / 1024 :.2f} Mb" )
181
+ pin_mem = default_device .type == "cuda"
182
+ if consolidated is True :
183
+ njt_td = njt_td .consolidate (pin_memory = pin_mem )
184
+
185
+ if consolidated == "within" :
186
+
187
+ def to (td , num_threads ):
188
+ return td .consolidate (pin_memory = pin_mem ).to (
189
+ default_device , num_threads = num_threads
190
+ )
191
+
192
+ else :
193
+
194
+ def to (td , num_threads ):
195
+ return td .to (default_device , num_threads = num_threads )
196
+
197
+ if compile_mode :
198
+ to = torch .compile (to , mode = compile_mode , dynamic = True )
199
+
200
+ for _ in range (3 ):
201
+ to (njt_td , num_threads = num_threads )
202
+
203
+ benchmark (to , njt_td , num_threads )
69
204
70
205
71
206
if __name__ == "__main__" :
72
207
args , unknown = argparse .ArgumentParser ().parse_known_args ()
73
- pytest .main ([__file__ , "--capture" , "no" , "--exitfirst" ] + unknown )
208
+ pytest .main (
209
+ [
210
+ __file__ ,
211
+ "--capture" ,
212
+ "no" ,
213
+ "--exitfirst" ,
214
+ "--benchmark-group-by" ,
215
+ "func" ,
216
+ "-vvv" ,
217
+ ]
218
+ + unknown
219
+ )
0 commit comments