Skip to content

Commit a7cb2bb

Browse files
author
Vincent Moens
committed
[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: 55de7c9 Pull Request resolved: #1041
1 parent 8c65dcb commit a7cb2bb

File tree

8 files changed

+611
-485
lines changed

8 files changed

+611
-485
lines changed

benchmarks/common/h2d_test.py

Lines changed: 170 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,40 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import argparse
7+
import time
8+
from typing import Any
79

810
import pytest
911
import torch
1012
from packaging import version
1113

12-
from tensordict import TensorDict
14+
from tensordict import tensorclass, TensorDict
15+
from tensordict.utils import logger as tensordict_logger
1316

1417
# Installed torch version, re-parsed from its `base_version` so that the
# skipif gates below compare against the plain release number only.
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
1518

1619

17-
@pytest.fixture
18-
def td():
19-
return TensorDict(
20-
{
21-
str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
22-
for i in range(16)
23-
},
24-
batch_size=[16],
25-
device="cpu",
26-
)
20+
@tensorclass
class NJT:
    """Tensorclass mirror of a nested jagged tensor's components.

    The three tensor fields carry the attributes of the same name read from
    the source tensor; ``njt_shape`` is a non-tensor field that defaults to
    ``None``.
    """

    _values: torch.Tensor
    _offsets: torch.Tensor
    _lengths: torch.Tensor
    njt_shape: Any = None

    @classmethod
    def from_njt(cls, njt_tensor):
        """Build an ``NJT`` from ``njt_tensor`` and return a clone of it.

        Args:
            njt_tensor: object exposing ``_values``, ``_offsets``,
                ``_lengths`` and ``size(0)`` (presumably a torch nested
                jagged tensor -- confirm against ``_make_njt``).

        Returns:
            The result of calling ``.clone()`` on the freshly built instance.
        """
        wrapped = cls(
            _values=njt_tensor._values,
            _offsets=njt_tensor._offsets,
            _lengths=njt_tensor._lengths,
            # Only the leading dimension is recorded.
            njt_shape=njt_tensor.size(0),
        )
        return wrapped.clone()
35+
36+
37+
@pytest.fixture(scope="function", autouse=True)
def empty_compiler_cache():
    """Call ``torch.compiler.reset()`` before every test function.

    The fixture is autouse and function-scoped, so it runs for each test
    without being requested; nothing is done after the ``yield``.
    """
    torch.compiler.reset()
    yield
2741

2842

2943
def _make_njt():
@@ -34,14 +48,29 @@ def _make_njt():
3448
)
3549

3650

37-
@pytest.fixture
38-
def njt_td():
51+
def _njt_td():
    """Return a CPU ``TensorDict`` with 128 entries built by ``_make_njt()``.

    Keys are the strings ``"0"`` to ``"127"``; each value comes from its own
    ``_make_njt()`` call.
    """
    # Previous nested 32 x 32 layout, kept for reference:
    # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
    entries = {}
    for idx in range(128):
        entries[str(idx)] = _make_njt()
    return TensorDict(entries, device="cpu")
)
4357

4458

59+
@pytest.fixture
def njt_td():
    """Fixture returning a fresh ``_njt_td()`` result for each test."""
    fresh = _njt_td()
    return fresh
62+
63+
64+
@pytest.fixture
def td():
    """Fixture: the ``_njt_td()`` container with every entry wrapped as ``NJT``.

    Each top-level value is replaced in place by ``NJT.from_njt(value)``;
    the key set is left unchanged.
    """
    container = _njt_td()
    for key, njt in container.items():
        container[key] = NJT.from_njt(njt)
        # Variant for a nested layout, kept for reference:
        # for k1, v1 in v0.items():
        #     njtd[k0, k1] = NJT.from_njt(v1)
    return container
72+
73+
4574
@pytest.fixture
4675
def default_device():
4776
if torch.cuda.is_available():
@@ -52,22 +81,139 @@ def default_device():
5281
pytest.skip("CUDA/MPS is not available")
5382

5483

55-
@pytest.mark.parametrize("consolidated", [False, True])
84+
@pytest.mark.parametrize(
    "compile_mode,num_threads",
    [
        [False, None],
        # [False, 4],
        # [False, 16],
        ["default", None],
        ["reduce-overhead", None],
    ],
)
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestConsolidate:
    """Benchmarks of ``consolidate`` in eager and ``torch.compile`` modes.

    Parametrization:
        compile_mode: ``False`` runs the function as-is; any other value is
            passed as ``mode=`` to ``torch.compile``.
        num_threads: forwarded to ``consolidate(num_threads=...)``.
    """

    def test_consolidate(self, benchmark, td, compile_mode, num_threads):
        """Benchmark ``td.consolidate`` on the ``td`` fixture.

        The first call is timed and logged on its own, three further
        warm-up calls follow, and only then is the function handed to
        ``benchmark``.
        """
        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

        def consolidate(td, num_threads):
            return td.consolidate(num_threads=num_threads)

        if compile_mode:
            consolidate = torch.compile(
                consolidate, mode=compile_mode, dynamic=True, fullgraph=True
            )

        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be distorted by system clock adjustments.
        t0 = time.perf_counter()
        consolidate(td, num_threads=num_threads)
        elapsed = time.perf_counter() - t0
        tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")

        # Warm-up calls before the measured runs.
        for _ in range(3):
            consolidate(td, num_threads=num_threads)

        benchmark(consolidate, td, num_threads)

    def test_to_njt(self, benchmark, njt_td, compile_mode, num_threads):
        """Benchmark ``consolidate`` on the ``njt_td`` fixture.

        Unlike ``test_consolidate``, the compiled variant here is built
        without ``fullgraph=True`` and the first call is not timed.
        """
        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")

        def consolidate(td, num_threads):
            return td.consolidate(num_threads=num_threads)

        if compile_mode:
            consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)

        # Warm-up calls before the measured runs.
        for _ in range(3):
            consolidate(njt_td, num_threads=num_threads)

        benchmark(consolidate, njt_td, num_threads)
132+
133+
134+
@pytest.mark.parametrize(
    "consolidated,compile_mode,num_threads",
    [
        [False, False, None],
        [True, False, None],
        ["within", False, None],
        # [True, False, 4],
        # [True, False, 16],
        [True, "default", None],
    ],
)
@pytest.mark.skipif(
    # The reason string names the exact minimum the condition enforces.
    TORCH_VERSION < version.parse("2.5.1"),
    reason="requires torch>=2.5.1",
)
class TestTo:
    """Benchmarks of ``.to(default_device)`` with and without consolidation.

    Parametrization:
        consolidated: ``False`` transfers the input as-is; ``True``
            consolidates it once before any timed call; ``"within"``
            consolidates inside the timed function on every call.
        compile_mode: ``False`` runs the function as-is; any other value is
            passed as ``mode=`` to ``torch.compile``.
        num_threads: forwarded to ``.to(..., num_threads=...)``.
    """

    def test_to(
        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
    ):
        """Benchmark the transfer of the ``td`` fixture to ``default_device``."""
        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
        # Pinned memory is requested only when the target device is CUDA.
        pin_mem = default_device.type == "cuda"
        if consolidated is True:
            td = td.consolidate(pin_memory=pin_mem)

        if consolidated == "within":

            def to(td, num_threads):
                return td.consolidate(pin_memory=pin_mem).to(
                    default_device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(default_device, num_threads=num_threads)

        if compile_mode:
            to = torch.compile(to, mode=compile_mode, dynamic=True)

        # Warm-up calls before the measured runs.
        for _ in range(3):
            to(td, num_threads=num_threads)

        benchmark(to, td, num_threads)

    def test_to_njt(
        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
    ):
        """Benchmark the transfer of the ``njt_td`` fixture to ``default_device``."""
        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
        # Pinned memory is requested only when the target device is CUDA.
        pin_mem = default_device.type == "cuda"
        if consolidated is True:
            njt_td = njt_td.consolidate(pin_memory=pin_mem)

        if consolidated == "within":

            def to(td, num_threads):
                return td.consolidate(pin_memory=pin_mem).to(
                    default_device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(default_device, num_threads=num_threads)

        if compile_mode:
            to = torch.compile(to, mode=compile_mode, dynamic=True)

        # Warm-up calls before the measured runs.
        for _ in range(3):
            to(njt_td, num_threads=num_threads)

        benchmark(to, njt_td, num_threads)
69204

70205

71206
if __name__ == "__main__":
    # Arguments argparse does not recognise are forwarded to pytest as-is.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [
        __file__,
        "--capture",
        "no",
        "--exitfirst",
        "--benchmark-group-by",
        "func",
        "-vvv",
    ]
    pytest_argv.extend(unknown)
    pytest.main(pytest_argv)

benchmarks/compile/compile_td_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ class MyTensorClass:
2323
f: torch.Tensor
2424

2525

26+
@pytest.fixture(scope="function", autouse=True)
def empty_compiler_cache():
    """Call ``torch._dynamo.reset_code_caches()`` before every test function.

    The fixture is autouse and function-scoped, so it runs for each test
    without being requested; nothing is done after the ``yield``.
    """
    torch._dynamo.reset_code_caches()
    yield
30+
31+
2632
# Functions
2733
def add_one(td):
    """Return ``td + 1``."""
    incremented = td + 1
    return incremented

tensordict/_reductions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _make_td(cls, state):
138138

139139
def _reduce_td(data: TensorDict):
140140
consolidated = getattr(data, "_consolidated", None)
141-
if consolidated and consolidated["metadata"] is not None:
141+
if isinstance(consolidated, dict):
142142
storage = consolidated["storage"]
143143
storge_metadata = consolidated["metadata"]
144144
return (

tensordict/_td.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4210,7 +4210,7 @@ def _iter():
42104210
if self.leaves_only:
42114211
for key in self._keys():
42124212
target_class = self.tensordict.entry_class(key)
4213-
if _is_tensor_collection(target_class):
4213+
if self.is_leaf(target_class):
42144214
continue
42154215
yield key
42164216
else:
@@ -4239,9 +4239,10 @@ def _iter_helper(
42394239
# For lazy stacks
42404240
value = value[0]
42414241
cls = type(value)
4242-
is_leaf = self.is_leaf(cls)
4243-
if self.include_nested and not is_leaf:
4242+
is_tc = _is_tensor_collection(cls)
4243+
if self.include_nested and is_tc:
42444244
yield from self._iter_helper(value, prefix=full_key)
4245+
is_leaf = self.is_leaf(cls)
42454246
if not self.leaves_only or is_leaf:
42464247
yield full_key
42474248

0 commit comments

Comments
 (0)