
Commit 0e7be29

Author: Vincent Moens
[Feature] compatibility of consolidate with compile (quick version)
ghstack-source-id: 1bf3ca5
Pull Request resolved: #1061
Parent: 752e6dc

File tree: 5 files changed, +223 −38 lines
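In short, this commit makes TensorDict.consolidate() usable under torch.compile. A minimal sketch of what the patch enables (assuming torch>=2.5 and a tensordict build that includes this change; sizes are illustrative):

    import torch
    from tensordict import TensorDict

    td = TensorDict({str(i): torch.randn(16, 16) for i in range(8)}, device="cpu")

    def consolidate(td):
        # Packs all leaves into one contiguous storage, so a later
        # .to(device) is a single transfer instead of one per leaf.
        return td.consolidate()

    # Before this patch, tracing consolidate() through dynamo was not supported.
    consolidate_c = torch.compile(consolidate, fullgraph=True, dynamic=False)
    consolidate_c(td)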

benchmarks/common/h2d_test.py

Lines changed: 184 additions & 25 deletions
@@ -4,26 +4,40 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
+import time
+from typing import Any

 import pytest
 import torch
 from packaging import version

-from tensordict import TensorDict
+from tensordict import tensorclass, TensorDict
+from tensordict.utils import logger as tensordict_logger

 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


-@pytest.fixture
-def td():
-    return TensorDict(
-        {
-            str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
-            for i in range(16)
-        },
-        batch_size=[16],
-        device="cpu",
-    )
+@tensorclass
+class NJT:
+    _values: torch.Tensor
+    _offsets: torch.Tensor
+    _lengths: torch.Tensor
+    njt_shape: Any = None
+
+    @classmethod
+    def from_njt(cls, njt_tensor):
+        return cls(
+            _values=njt_tensor._values,
+            _offsets=njt_tensor._offsets,
+            _lengths=njt_tensor._lengths,
+            njt_shape=njt_tensor.size(0),
+        ).clone()
+
+
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch.compiler.reset()
+    yield


 def _make_njt():
@@ -34,14 +48,29 @@ def _make_njt():
     )


-@pytest.fixture
-def njt_td():
+def _njt_td():
     return TensorDict(
-        {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
+        # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
+        {str(i): _make_njt() for i in range(32)},
         device="cpu",
     )


+@pytest.fixture
+def njt_td():
+    return _njt_td()
+
+
+@pytest.fixture
+def td():
+    njtd = _njt_td()
+    for k0, v0 in njtd.items():
+        njtd[k0] = NJT.from_njt(v0)
+        # for k1, v1 in v0.items():
+        #     njtd[k0, k1] = NJT.from_njt(v1)
+    return njtd
+
+
 @pytest.fixture
 def default_device():
     if torch.cuda.is_available():
@@ -52,22 +81,152 @@ def default_device():
         pytest.skip("CUDA/MPS is not available")


-@pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.parametrize(
+    "compile_mode,num_threads",
+    [
+        [False, None],
+        # [False, 4],
+        # [False, 16],
+        ["default", None],
+        ["reduce-overhead", None],
+    ],
+)
 @pytest.mark.skipif(
-    TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5"
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
 )
+class TestConsolidate:
+    def test_consolidate(
+        self, benchmark, td, compile_mode, num_threads, default_device
+    ):
+        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
+
+        # td = td.to(default_device)
+
+        def consolidate(td, num_threads):
+            return td.consolidate(num_threads=num_threads)
+
+        if compile_mode:
+            consolidate = torch.compile(
+                consolidate, mode=compile_mode, dynamic=False, fullgraph=True
+            )
+
+        t0 = time.time()
+        consolidate(td, num_threads=num_threads)
+        elapsed = time.time() - t0
+        tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")
+
+        for _ in range(3):
+            consolidate(td, num_threads=num_threads)
+
+        benchmark(consolidate, td, num_threads)
+
+    def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads):
+        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
+
+        def consolidate(td, num_threads):
+            return td.consolidate(num_threads=num_threads)
+
+        if compile_mode:
+            pytest.skip(
+                "Compiling NJTs consolidation currently triggers a RuntimeError."
+            )
+            # consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)
+
+        for _ in range(3):
+            consolidate(njt_td, num_threads=num_threads)
+
+        benchmark(consolidate, njt_td, num_threads)
+
+
+@pytest.mark.parametrize(
+    "consolidated,compile_mode,num_threads",
+    [
+        [False, False, None],
+        [True, False, None],
+        ["within", False, None],
+        # [True, False, 4],
+        # [True, False, 16],
+        [True, "default", None],
+    ],
+)
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.2"), reason="requires torch>=2.5"
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found")
 class TestTo:
-    def test_to(self, benchmark, consolidated, td, default_device):
-        if consolidated:
-            td = td.consolidate()
-        benchmark(lambda: td.to(default_device))
+    def test_to(
+        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
+    ):
+        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
+        pin_mem = default_device.type == "cuda"
+        if consolidated is True:
+            td = td.consolidate(pin_memory=pin_mem)
+
+        if consolidated == "within":
+
+            def to(td, num_threads):
+                return td.consolidate(pin_memory=pin_mem).to(
+                    default_device, num_threads=num_threads
+                )
+
+        else:
+
+            def to(td, num_threads):
+                return td.to(default_device, num_threads=num_threads)

-    def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
-        if consolidated:
-            njt_td = njt_td.consolidate()
-        benchmark(lambda: njt_td.to(default_device))
+        if compile_mode:
+            to = torch.compile(to, mode=compile_mode, dynamic=True)
+
+        for _ in range(3):
+            to(td, num_threads=num_threads)
+
+        benchmark(to, td, num_threads)
+
+    def test_to_njt(
+        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
+    ):
+        if compile_mode:
+            pytest.skip(
+                "Compiling NJTs consolidation currently triggers a RuntimeError."
+            )
+
+        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
+        pin_mem = default_device.type == "cuda"
+        if consolidated is True:
+            njt_td = njt_td.consolidate(pin_memory=pin_mem)
+
+        if consolidated == "within":
+
+            def to(td, num_threads):
+                return td.consolidate(pin_memory=pin_mem).to(
+                    default_device, num_threads=num_threads
+                )
+
+        else:
+
+            def to(td, num_threads):
+                return td.to(default_device, num_threads=num_threads)
+
+        if compile_mode:
+            to = torch.compile(to, mode=compile_mode, dynamic=True)
+
+        for _ in range(3):
+            to(njt_td, num_threads=num_threads)
+
+        benchmark(to, njt_td, num_threads)


 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
-    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
+    pytest.main(
+        [
+            __file__,
+            "--capture",
+            "no",
+            "--exitfirst",
+            "--benchmark-group-by",
+            "func",
+            "-vvv",
+        ]
+        + unknown
+    )
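For readers skimming the benchmark above: the `consolidated` parameter of `TestTo` times three distinct transfer strategies. A condensed sketch of the three paths (assuming a CUDA device; every call below appears in the diff):

    import torch
    from tensordict import TensorDict

    td = TensorDict({str(i): torch.randn(16, 16) for i in range(8)}, device="cpu")
    device = torch.device("cuda")

    # consolidated=False: one host-to-device copy per leaf tensor.
    td.to(device)

    # consolidated=True: pack into pinned host memory once, outside the timed
    # region; the timed .to(device) then moves a single flat storage.
    packed = td.consolidate(pin_memory=True)
    packed.to(device)

    # consolidated="within": packing happens inside the timed region, so the
    # benchmark measures consolidation plus transfer end to end.
    td.consolidate(pin_memory=True).to(device)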

benchmarks/compile/tensordict_nn_test.py

Lines changed: 12 additions & 2 deletions
@@ -21,7 +21,14 @@

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-torch.set_default_device(DEVICE)
+
+@pytest.fixture(scope="function", autouse=True)
+def auto_device():
+    device = torch.get_default_device()
+    torch.set_default_device(DEVICE)
+    yield
+    torch.set_default_device(device)
+

 compile = functools.partial(torch.compile, fullgraph=True)
 compile_overhead = functools.partial(
@@ -32,7 +39,10 @@
 @pytest.fixture(scope="function", autouse=True)
 def reset_dynamo():
     # Start a fresh compile for each parameter of the test case
-    torch._dynamo.reset()
+    try:
+        torch.compiler.reset()
+    except AttributeError:
+        torch._dynamo.reset()
     gc.collect()
     yield
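Two small hygiene fixes here: the module no longer mutates the process-wide default device at import time, and dynamo state is reset through the public torch.compiler API where available, falling back to the private torch._dynamo.reset() on older builds. A sketch of the save/restore idiom the auto_device fixture relies on (the helper name is illustrative, not part of the patch; torch.get_default_device() requires a reasonably recent torch):

    import torch

    def with_default_device(device, fn):
        # Same logic as the auto_device fixture: set a temporary default
        # device, run, then restore whatever default was set before.
        prev = torch.get_default_device()
        torch.set_default_device(device)
        try:
            return fn()
        finally:
            torch.set_default_device(prev)

    t = with_default_device("cpu", lambda: torch.empty(3))
    assert t.device.type == "cpu"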

tensordict/base.py

Lines changed: 9 additions & 3 deletions
@@ -3860,8 +3860,9 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
             pad = 8 - pad
         else:
             pad = 0
-        flat_size.append(n + pad)
-        stop = start + flat_size[-1]
+        flat_size.append(sum([n, pad]))
+        # Using sum to tell dynamo to use sym_sum
+        stop = sum([start, flat_size[-1]])
         if requires_metadata:
             metadata_dict["leaves"][key] = (
                 _DTYPE2STRDTYPE[dtype],
@@ -4136,6 +4137,8 @@ def view_old_as_new(v, oldv):
             return v[: oldv.numel()].view(oldv.shape)
         return v.view(oldv.shape)

+    if num_threads is None:
+        num_threads = 0
     if num_threads > 0:

         def assign(
@@ -4241,7 +4244,10 @@ def _view_and_pad(tensor):
             if v.device != storage.device:
                 v = v.to(storage.device, non_blocking=non_blocking)
             stride = v.stride()
-            if (stride and stride[-1] != 1) or v.storage_offset():
+            if is_dynamo_compiling():
+                if not v.is_contiguous():
+                    v = v.clone(memory_format=torch.contiguous_format)
+            elif (stride and stride[-1] != 1) or v.storage_offset():
                 v = v.clone(memory_format=torch.contiguous_format)
             v, pad = _view_and_pad(v)
             items.append(v)
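Both base.py changes keep consolidate() dynamo-friendly. Per the in-diff comment, sum([n, pad]) lets dynamo lower the addition of symbolic ints to a single sym_sum node rather than a chain of binary adds, and under compilation the contiguity test goes through is_contiguous() instead of inspecting strides and storage offsets, which likely avoids extra guards. A standalone sketch of the sum pattern (illustrative, not part of the patch):

    import torch

    @torch.compile(dynamic=True, fullgraph=True)
    def pad_to_multiple_of_8(x):
        n = x.numel()          # a symbolic int under dynamic shapes
        pad = (8 - n % 8) % 8  # elements needed to reach an 8-boundary
        # sum() over a list of symbolic ints can become a single sym_sum node.
        return x.new_zeros(sum([n, pad]))

    print(pad_to_multiple_of_8(torch.randn(10)).numel())  # 16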

tensordict/tensorclass.py

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ def __subclasscheck__(self, subclass):
     "_multithread_rebuild",  # rebuild checks if self is a non tensor
     "_propagate_lock",
     "_propagate_unlock",
+    "_reduce_get_metadata",
     "_values_list",
     "data_ptr",
     "dim",

test/test_tensordict.py

Lines changed: 17 additions & 8 deletions
@@ -6380,27 +6380,36 @@ def test_to_device_dtype_inplace(self, td_name, device):
         td = getattr(self, td_name)(device)
         if torch.cuda.is_available():
             dest = torch.device("cuda:0")
-        elif torch.mps.is_available():
-            dest = torch.device("mps:0")
+        # elif torch.mps.is_available():
+        #     dest = torch.device("mps:0")
         else:
             dest = torch.device("cpu")

         if td_name in ("sub_td", "sub_td2"):
-            cm = pytest.raises(
+            cm_device = cm_dtype = pytest.raises(
                 TypeError,
                 match="Cannot send a _SubTensorDict instance to device/dtype inplace",
             )
         elif td_name in ("permute_td", "unsqueezed_td", "squeezed_td", "td_h5"):
-            cm = pytest.raises(TypeError, match="Cannot use inplace=True with")
+            cm_device = cm_dtype = pytest.raises(
+                TypeError, match="Cannot use inplace=True with"
+            )
+        elif td_name in ("memmap_td",) and dest.type == "cpu":
+            cm_device = contextlib.nullcontext()
+            cm_dtype = pytest.raises(
+                RuntimeError, match="Cannot modify locked TensorDict."
+            )
         elif td.is_locked:
-            cm = pytest.raises(RuntimeError, match="Cannot modify locked TensorDict.")
+            cm_device = cm_dtype = pytest.raises(
+                RuntimeError, match="Cannot modify locked TensorDict."
+            )
         else:
-            cm = contextlib.nullcontext()
-        with cm:
+            cm_device = cm_dtype = contextlib.nullcontext()
+        with cm_dtype:
             td.to(torch.float32, inplace=True)
             assert td.dtype == torch.float32, td

-        with cm:
+        with cm_device:
             td.to(dest, inplace=True)
             assert td.device == dest
             for v in td.values(
