From b20468c52f7fb652ea16e3c2257d6e0af8d7798b Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 25 Oct 2024 10:42:40 -0700
Subject: [PATCH 1/7] Update

[ghstack-poisoned]
---
 benchmarks/common/h2d_test.py | 197 +++++++++++++++++++++++++++++-----
 tensordict/base.py            |  12 ++-
 tensordict/tensorclass.py     |   1 +
 3 files changed, 183 insertions(+), 27 deletions(-)

diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py
index b08298dc1..227aa8106 100644
--- a/benchmarks/common/h2d_test.py
+++ b/benchmarks/common/h2d_test.py
@@ -4,26 +4,40 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
+import time
+from typing import Any

 import pytest
 import torch
 from packaging import version
-from tensordict import TensorDict
+from tensordict import tensorclass, TensorDict
+from tensordict.utils import logger as tensordict_logger

 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


-@pytest.fixture
-def td():
-    return TensorDict(
-        {
-            str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
-            for i in range(16)
-        },
-        batch_size=[16],
-        device="cpu",
-    )
+@tensorclass
+class NJT:
+    _values: torch.Tensor
+    _offsets: torch.Tensor
+    _lengths: torch.Tensor
+    njt_shape: Any = None
+
+    @classmethod
+    def from_njt(cls, njt_tensor):
+        return cls(
+            _values=njt_tensor._values,
+            _offsets=njt_tensor._offsets,
+            _lengths=njt_tensor._lengths,
+            njt_shape=njt_tensor.size(0),
+        )
+
+
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch.compiler.reset()
+    yield


 def _make_njt():
@@ -34,14 +48,29 @@ def _make_njt():
     )


-@pytest.fixture
-def njt_td():
+def _njt_td():
     return TensorDict(
-        {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
+        # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
+        {str(i): _make_njt() for i in range(8)},
         device="cpu",
     )


+@pytest.fixture
+def njt_td():
+    return _njt_td()
+
+
+@pytest.fixture
+def td():
+    njtd = _njt_td()
+    for k0, v0 in njtd.items():
+        njtd[k0] = NJT.from_njt(v0)
+        # for k1, v1 in v0.items():
+        #     njtd[k0, k1] = NJT.from_njt(v1)
+    return njtd
+
+
 @pytest.fixture
 def default_device():
     if torch.cuda.is_available():
@@ -52,22 +81,142 @@ def default_device():
         pytest.skip("CUDA/MPS is not available")


-@pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.parametrize(
+    "compile_mode,num_threads",
+    [
+        [False, None],
+        # [False, 4],
+        # [False, 16],
+        ["default", None],
+        ["reduce-overhead", None],
+    ],
+)
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
+)
+class TestConsolidate:
+    def test_consolidate(self, benchmark, td, compile_mode, num_threads):
+        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
+
+        def consolidate(td, num_threads):
+            return td.consolidate(num_threads=num_threads)
+
+        if compile_mode:
+            consolidate = torch.compile(
+                consolidate, mode=compile_mode, dynamic=True, fullgraph=True
+            )
+
+        t0 = time.time()
+        consolidate(td, num_threads=num_threads)
+        elapsed = time.time() - t0
+        tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")
+
+        for _ in range(3):
+            consolidate(td, num_threads=num_threads)
+
+        benchmark(consolidate, td, num_threads)
+
+    def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads):
+        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
+
+        def consolidate(td, num_threads):
+            return td.consolidate(num_threads=num_threads)
+
+        if compile_mode:
+            pytest.skip(
"Compiling NJTs consolidation currently triggers a RuntimeError." + ) + # consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True) + + for _ in range(3): + consolidate(njt_td, num_threads=num_threads) + + benchmark(consolidate, njt_td, num_threads) + + +@pytest.mark.parametrize( + "consolidated,compile_mode,num_threads", + [ + [False, False, None], + [True, False, None], + ["within", False, None], + # [True, False, 4], + # [True, False, 16], + [True, "default", None], + ], +) @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5" ) class TestTo: - def test_to(self, benchmark, consolidated, td, default_device): - if consolidated: - td = td.consolidate() - benchmark(lambda: td.to(default_device)) + def test_to( + self, benchmark, consolidated, td, default_device, compile_mode, num_threads + ): + tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") + pin_mem = default_device.type == "cuda" + if consolidated is True: + td = td.consolidate(pin_memory=pin_mem) + + if consolidated == "within": + + def to(td, num_threads): + return td.consolidate(pin_memory=pin_mem).to( + default_device, num_threads=num_threads + ) + + else: - def test_to_njt(self, benchmark, consolidated, njt_td, default_device): - if consolidated: - njt_td = njt_td.consolidate() - benchmark(lambda: njt_td.to(default_device)) + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) + + if compile_mode: + to = torch.compile(to, mode=compile_mode, dynamic=True) + + for _ in range(3): + to(td, num_threads=num_threads) + + benchmark(to, td, num_threads) + + def test_to_njt( + self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads + ): + tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") + pin_mem = default_device.type == "cuda" + if consolidated is True: + njt_td = njt_td.consolidate(pin_memory=pin_mem) + + if consolidated == "within": + + def to(td, num_threads): + return td.consolidate(pin_memory=pin_mem).to( + default_device, num_threads=num_threads + ) + + else: + + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) + + if compile_mode: + to = torch.compile(to, mode=compile_mode, dynamic=True) + + for _ in range(3): + to(njt_td, num_threads=num_threads) + + benchmark(to, njt_td, num_threads) if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main( + [ + __file__, + "--capture", + "no", + "--exitfirst", + "--benchmark-group-by", + "func", + "-vvv", + ] + + unknown + ) diff --git a/tensordict/base.py b/tensordict/base.py index 79ad2cfaf..0053b2634 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3860,8 +3860,9 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size): pad = 8 - pad else: pad = 0 - flat_size.append(n + pad) - stop = start + flat_size[-1] + flat_size.append(sum([n, pad])) + # Using sum to tell dynamo to use sym_sum + stop = sum([start, flat_size[-1]]) if requires_metadata: metadata_dict["leaves"][key] = ( _DTYPE2STRDTYPE[dtype], @@ -4136,6 +4137,8 @@ def view_old_as_new(v, oldv): return v[: oldv.numel()].view(oldv.shape) return v.view(oldv.shape) + if num_threads is None: + num_threads = 0 if num_threads > 0: def assign( @@ -4241,7 +4244,10 @@ def _view_and_pad(tensor): if v.device != storage.device: v = v.to(storage.device, non_blocking=non_blocking) stride = v.stride() - if (stride and stride[-1] != 
+            if is_dynamo_compiling():
+                if not v.is_contiguous():
+                    v = v.clone(memory_format=torch.contiguous_format)
+            elif (stride and stride[-1] != 1) or v.storage_offset():
                 v = v.clone(memory_format=torch.contiguous_format)
             v, pad = _view_and_pad(v)
             items.append(v)
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 37abe37c1..6f6f9cc13 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -135,6 +135,7 @@ def __subclasscheck__(self, subclass):
     "_multithread_rebuild",  # rebuild checks if self is a non tensor
     "_propagate_lock",
     "_propagate_unlock",
+    "_reduce_get_metadata",
     "_values_list",
     "data_ptr",
     "dim",

From 68b69f0947cb94e90aa1bcb9acbd5b4dd0ea4efc Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 29 Oct 2024 09:38:46 +0000
Subject: [PATCH 2/7] Update

[ghstack-poisoned]
---
 benchmarks/common/h2d_test.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py
index 227aa8106..6e5eb509b 100644
--- a/benchmarks/common/h2d_test.py
+++ b/benchmarks/common/h2d_test.py
@@ -31,7 +31,7 @@ def from_njt(cls, njt_tensor):
             _offsets=njt_tensor._offsets,
             _lengths=njt_tensor._lengths,
             njt_shape=njt_tensor.size(0),
-        )
+        ).clone()


 @pytest.fixture(autouse=True, scope="function")
@@ -148,6 +148,7 @@ def consolidate(td, num_threads):
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5"
 )
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found")
 class TestTo:
     def test_to(
         self, benchmark, consolidated, td, default_device, compile_mode, num_threads

From 02aa6ed1c1e2a4b730edd8eb494efa8892cf2f4d Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 31 Oct 2024 11:11:51 +0000
Subject: [PATCH 3/7] Update

[ghstack-poisoned]
---
 benchmarks/common/h2d_test.py            | 10 +++++++---
 benchmarks/compile/tensordict_nn_test.py | 14 ++++++++++++--
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py
index 6e5eb509b..059601fb4 100644
--- a/benchmarks/common/h2d_test.py
+++ b/benchmarks/common/h2d_test.py
@@ -51,7 +51,7 @@ def _make_njt():
 def _njt_td():
     return TensorDict(
         # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
-        {str(i): _make_njt() for i in range(8)},
+        {str(i): _make_njt() for i in range(32)},
         device="cpu",
     )

@@ -95,15 +95,19 @@ def default_device():
     TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
 )
 class TestConsolidate:
-    def test_consolidate(self, benchmark, td, compile_mode, num_threads):
+    def test_consolidate(
+        self, benchmark, td, compile_mode, num_threads, default_device
+    ):
         tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

+        # td = td.to(default_device)
+
         def consolidate(td, num_threads):
             return td.consolidate(num_threads=num_threads)

         if compile_mode:
             consolidate = torch.compile(
-                consolidate, mode=compile_mode, dynamic=True, fullgraph=True
+                consolidate, mode=compile_mode, dynamic=False, fullgraph=True
             )

         t0 = time.time()
diff --git a/benchmarks/compile/tensordict_nn_test.py b/benchmarks/compile/tensordict_nn_test.py
index 7828c29f6..9338c566a 100644
--- a/benchmarks/compile/tensordict_nn_test.py
+++ b/benchmarks/compile/tensordict_nn_test.py
@@ -21,7 +21,14 @@

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-torch.set_default_device(DEVICE)
+
+@pytest.fixture(scope="function", autouse=True)
+def auto_device():
+    device = torch.get_default_device()
+    torch.set_default_device(DEVICE)
+    yield
+    torch.set_default_device(device)
+

 compile = functools.partial(torch.compile, fullgraph=True)
 compile_overhead = functools.partial(
@@ -32,7 +39,10 @@
 @pytest.fixture(scope="function", autouse=True)
 def reset_dynamo():
     # Start a fresh compile for each parameter of the test case
-    torch._dynamo.reset()
+    try:
+        torch.compiler.reset()
+    except AttributeError:
+        torch._dynamo.reset()
     gc.collect()
     yield

From 5467cbba16abb401498a3e3f228aff4546aa6116 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 31 Oct 2024 19:35:48 +0000
Subject: [PATCH 4/7] Update

[ghstack-poisoned]
---
 benchmarks/common/h2d_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py
index 059601fb4..092d1e2b7 100644
--- a/benchmarks/common/h2d_test.py
+++ b/benchmarks/common/h2d_test.py
@@ -150,7 +150,7 @@ def consolidate(td, num_threads):
     ],
 )
 @pytest.mark.skipif(
-    TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5"
+    TORCH_VERSION < version.parse("2.5.2"), reason="requires torch>=2.5"
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found")
 class TestTo:

From 29250b9b41f618d4086c56df8238791d13daebd1 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sat, 2 Nov 2024 21:21:31 +0000
Subject: [PATCH 5/7] Update

[ghstack-poisoned]
---
 benchmarks/common/h2d_test.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py
index 092d1e2b7..34d85885d 100644
--- a/benchmarks/common/h2d_test.py
+++ b/benchmarks/common/h2d_test.py
@@ -185,6 +185,11 @@ def to(td, num_threads):
     def test_to_njt(
         self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
     ):
+        if compile_mode:
+            pytest.skip(
+                "Compiling NJTs consolidation currently triggers a RuntimeError."
+            )
+
         tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
         pin_mem = default_device.type == "cuda"
         if consolidated is True:
             njt_td = njt_td.consolidate(pin_memory=pin_mem)

From d2b86c319b9cfa2083c851a7fa43a83062a02c3f Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sat, 2 Nov 2024 21:48:10 +0000
Subject: [PATCH 6/7] Update

[ghstack-poisoned]
---
 test/test_tensordict.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index fcd1626e1..f84585777 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -6392,6 +6392,8 @@ def test_to_device_dtype_inplace(self, td_name, device):
             )
         elif td_name in ("permute_td", "unsqueezed_td", "squeezed_td", "td_h5"):
             cm = pytest.raises(TypeError, match="Cannot use inplace=True with")
+        elif td_name in ("memmap_td",) and dest.type == "cpu":
+            cm = contextlib.nullcontext()
         elif td.is_locked:
             cm = pytest.raises(RuntimeError, match="Cannot modify locked TensorDict.")
         else:

From 597b5f9dafede3dcf4a72c8ac41166d9faf5a45e Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 4 Nov 2024 08:04:06 +0000
Subject: [PATCH 7/7] Update

[ghstack-poisoned]
---
 test/test_tensordict.py | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index f84585777..57d94ad88 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -6380,29 +6380,36 @@ def test_to_device_dtype_inplace(self, td_name, device):
         td = getattr(self, td_name)(device)
         if torch.cuda.is_available():
             dest = torch.device("cuda:0")
-        elif torch.mps.is_available():
-            dest = torch.device("mps:0")
+        # elif torch.mps.is_available():
+        #     dest = torch.device("mps:0")
         else:
             dest = torch.device("cpu")

         if td_name in ("sub_td", "sub_td2"):
-            cm = pytest.raises(
+            cm_device = cm_dtype = pytest.raises(
                 TypeError,
                 match="Cannot send a _SubTensorDict instance to device/dtype inplace",
             )
         elif td_name in ("permute_td", "unsqueezed_td", "squeezed_td", "td_h5"):
-            cm = pytest.raises(TypeError, match="Cannot use inplace=True with")
+            cm_device = cm_dtype = pytest.raises(
+                TypeError, match="Cannot use inplace=True with"
+            )
         elif td_name in ("memmap_td",) and dest.type == "cpu":
-            cm = contextlib.nullcontext()
+            cm_device = contextlib.nullcontext()
+            cm_dtype = pytest.raises(
+                RuntimeError, match="Cannot modify locked TensorDict."
+            )
         elif td.is_locked:
-            cm = pytest.raises(RuntimeError, match="Cannot modify locked TensorDict.")
+            cm_device = cm_dtype = pytest.raises(
+                RuntimeError, match="Cannot modify locked TensorDict."
+            )
         else:
-            cm = contextlib.nullcontext()
-        with cm:
+            cm_device = cm_dtype = contextlib.nullcontext()
+        with cm_dtype:
             td.to(torch.float32, inplace=True)
             assert td.dtype == torch.float32, td
-        with cm:
+        with cm_device:
             td.to(dest, inplace=True)
             assert td.device == dest
             for v in td.values(