Commit 01fdc9d

Author: Vincent Moens (committed)

[Performance] Make _to_consolidated compatible with compile

ghstack-source-id: 25f5d9a
Pull Request resolved: #1041

1 parent e696708 commit 01fdc9d
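
For orientation (commentary, not part of the commit): the change targets calling .to(device) on a consolidated TensorDict from inside a torch.compile'd function. Below is a minimal usage sketch, assuming the public TensorDict.consolidate()/to() API and an available CUDA device; the tensors and the compiled wrapper are illustrative only.

    # Sketch only -- assumes tensordict's public consolidate()/to() API.
    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"a": torch.randn(3), "nested": {"b": torch.randn(4)}}, batch_size=[]
    ).consolidate()  # packs all leaves into one flat storage plus metadata

    @torch.compile
    def move(td):
        # the path this commit is meant to make compile-compatible
        return td.to("cuda")

    if torch.cuda.is_available():
        td_cuda = move(td)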

File tree

1 file changed: +124 -2 lines changed

tensordict/base.py

Lines changed: 124 additions & 2 deletions
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
 
         flat_size = []
         start = 0
+        sorting_index = 0
 
         def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
-            nonlocal start
+            nonlocal start, sorting_index
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                 start,
                 stop,
                 pad,
+                flat_size[-1],
+                sorting_index,
             )
+            sorting_index = sorting_index + 1
             start = stop
 
         def assign(
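
Note on the two hunks above (commentary, not part of the diff): each leaf's metadata entry now also records its byte length (flat_size[-1]) and its insertion order (sorting_index), which is what lets the flat storage be re-split into per-leaf slices later with a single Tensor.split call. A minimal, self-contained sketch of that pattern with hypothetical values:

    import torch

    # key -> (byte length, position), mirroring (flat_size[-1], sorting_index)
    metadata = {"a": (8, 0), "b": (4, 1), "c": (12, 2)}
    storage = torch.arange(24, dtype=torch.uint8)  # stand-in for the flat storage

    lengths = [None] * len(metadata)
    keys = [None] * len(metadata)
    for key, (length, pos) in metadata.items():
        lengths[pos] = length  # restore the recorded insertion order
        keys[pos] = key

    slices = dict(zip(keys, storage.split(lengths)))  # one split, views only
    assert slices["b"].numel() == 4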
@@ -10390,7 +10394,7 @@ def to(self, *args, **kwargs) -> T:
             return result
 
         if self.is_consolidated() and dtype is None:
-            return self._to_consolidated(
+            return self._to_consolidated_compile(
                 device=device,
                 pin_memory=non_blocking_pin,
                 num_threads=num_threads,
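
Note on the hunk above (commentary, not part of the diff): to() only routes through the renamed fast path when the tensordict is already consolidated and no dtype conversion is requested, presumably because a dtype cast would change the byte sizes recorded in the consolidated metadata. A small sketch of the eligibility condition, assuming the public consolidate()/is_consolidated()/to() API:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.zeros(2)}, batch_size=[])
    assert not td.is_consolidated()

    td_c = td.consolidate()
    assert td_c.is_consolidated()      # eligible for the consolidated fast path
    td_half = td_c.to(torch.float16)   # dtype cast: the generic path is used instead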
@@ -10542,6 +10546,124 @@ def copy_dict(d):
 
         return result
 
+    def _to_consolidated_compile(self, *, device, pin_memory, num_threads, non_blocking):
+
+        def get_l(metadata, lengths=None, pos=None, keys=None, prefix=()):
+            root = False
+            if lengths is None:
+                lengths = []
+                pos = []
+                keys = []
+                root = True
+            for k, v in metadata["leaves"].items():
+                lengths.append(v[-2])
+                pos.append(v[-1])
+                keys.append(prefix + (k,))
+            for k, d in metadata.items():
+                if "leaves" in d:
+                    get_l(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,))
+            if root:
+                # l = torch.empty(len(lengths), dtype=torch.long)
+                # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
+                out0 = [None, ] * len(pos)
+                out1 = [None, ] * len(pos)
+                for p, l, k in zip(pos, lengths, keys):
+                    out0[p] = k
+                    out1[p] = l
+                return out0, out1
+
+        def split_storage(consolidated):
+            keys, splits = get_l(consolidated["metadata"])
+            return dict(zip(keys, consolidated["storage"].split(splits)))
+
+        if num_threads is None:
+            # unspecified num_threads should mean 0
+            num_threads = 0
+        storage = self._consolidated["storage"]
+        if pin_memory:
+            storage = storage.pin_memory()
+        storage_cast = storage.to(device, non_blocking=True)
+
+        _consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
+
+        slice_map = split_storage(_consolidated)
+
+        def set_(name, x):
+            if not isinstance(name, tuple):
+                name = (name,)
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                kwargs["offsets"] = slice_map[(*name[:-1], "<NJT_OFFSETS>"+name[-1],)].view(offsets.dtype).view(offsets.shape)
+                if lengths is not None:
+                    kwargs["lengths"] = slice_map[(*name[:-1], "<NJT_LENGTHS>"+name[-1],)].view(lengths.dtype).view(lengths.shape)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                return NestedTensor(
+                    slice_map[(*name[:-1], "<NJT_VALUES>"+name[-1],)].view(values.dtype).view(values.shape),
+                    **kwargs,
+                )
+            return slice_map[name].view(x.dtype).view(x.shape)
+
+        result = self._fast_apply(
+            set_, device=torch.device(device), num_threads=num_threads, named=True, nested_keys=True,
+        )
+        result._consolidated = _consolidated
+
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
     def _sync_all(self):
         if _has_cuda:
             # TODO: dynamo doesn't like torch.cuda.is_initialized
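
How the new method works, in brief (commentary, not part of the diff): the flat consolidated storage is moved to the target device with a single to() call (after an optional pin_memory()), the metadata dict is duplicated with a cheap recursive dict copy, and every leaf is then rebuilt by _fast_apply as a dtype-and-shape view over its slice of the cast storage; jagged NestedTensor leaves are reassembled from their <NJT_VALUES>/<NJT_OFFSETS>/<NJT_LENGTHS> slices, and a final stream synchronize covers the blocking cases. A minimal, self-contained sketch of the slice-and-view reconstruction applied to an ordinary leaf; the names are hypothetical:

    import torch

    original = torch.randn(2, 3)
    flat = original.view(torch.uint8).reshape(-1)  # stand-in for the consolidated byte storage
    n_bytes = original.numel() * original.element_size()

    leaf_slice = flat[:n_bytes]                    # what slice_map[name] would provide
    rebuilt = leaf_slice.view(original.dtype).view(original.shape)
    assert torch.equal(rebuilt, original)          # views only, no data copy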
