diff --git a/.github/workflows/test-rl-gpu.yml b/.github/workflows/test-rl-gpu.yml index ab284d754..dea057e8b 100644 --- a/.github/workflows/test-rl-gpu.yml +++ b/.github/workflows/test-rl-gpu.yml @@ -31,6 +31,7 @@ jobs: repository: pytorch/tensordict gpu-arch-type: cuda gpu-arch-version: ${{ matrix.cuda_arch_version }} + timeout: 120 script: | # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst index 17518df1d..8e6e4e907 100644 --- a/docs/source/reference/tensorclass.rst +++ b/docs/source/reference/tensorclass.rst @@ -273,3 +273,4 @@ Here is an example: tensorclass NonTensorData + NonTensorStack diff --git a/tensordict/__init__.py b/tensordict/__init__.py index c71661ceb..5e6dc8761 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -16,7 +16,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership from tensordict.persistent import PersistentTensorDict -from tensordict.tensorclass import NonTensorData, tensorclass +from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass from tensordict.utils import ( assert_allclose_td, is_batchedtensor, @@ -43,6 +43,7 @@ "TensorDict", "TensorDictBase", "merge_tensordicts", + "NonTensorStack", "set_transfer_ownership", "pad_sequence", "is_memmap", diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index b0c6ee6cc..971ef142d 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -53,6 +53,7 @@ expand_right, IndexType, infer_size_impl, + is_non_tensor, is_tensorclass, KeyedJaggedTensor, lock_blocked, @@ -134,6 +135,8 @@ class LazyStackedTensorDict(TensorDictBase): `td.ndimension()-1` along which the stack should be performed. hook_out (callable, optional): a callable to execute after :meth:`~.get`. hook_in (callable, optional): a callable to execute before :meth:`~.set`. + stack_dim_name (str, optional): the name of the stack dimension. + Defaults to ``None``. Examples: >>> from tensordict import TensorDict @@ -184,6 +187,7 @@ def __init__( hook_out: callable | None = None, hook_in: callable | None = None, batch_size: Sequence[int] | None = None, # TODO: remove + stack_dim_name: str | None = None, ) -> None: self._is_locked = None @@ -200,6 +204,10 @@ def __init__( ) _batch_size = tensordicts[0].batch_size device = tensordicts[0].device + if stack_dim > len(_batch_size): + raise RuntimeError( + f"Stack dim {stack_dim} is too big for batch size {_batch_size}." 
+ ) for td in tensordicts[1:]: if not is_tensor_collection(td): @@ -224,6 +232,8 @@ def __init__( self.hook_in = hook_in if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") + if stack_dim_name is not None: + self._td_dim_name = stack_dim_name # These attributes should never be set @property @@ -487,9 +497,10 @@ def _split_index(self, index): isinteger = False is_nd_tensor = False cursor = 0 # the dimension cursor - selected_td_idx = range(len(self.tensordicts)) + selected_td_idx = torch.arange(len(self.tensordicts)) has_bool = False num_squash = 0 + encountered_tensor = False for i, idx in enumerate(index): # noqa: B007 cursor_incr = 1 if idx is None: @@ -509,10 +520,8 @@ def _split_index(self, index): if not isinstance(selected_td_idx, range): isinteger = True selected_td_idx = [selected_td_idx] - elif isinstance(idx, (list, range)): - selected_td_idx = idx - elif isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype in (np.dtype("bool"), torch.bool): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: # we mark that we need to dispatch the indices across stack idx has_bool = True # split mask along dim @@ -522,11 +531,14 @@ def _split_index(self, index): split_dim = self.stack_dim - num_single mask_loc = i else: - if isinstance(idx, np.ndarray): - idx = torch.tensor(idx) is_nd_tensor = True - selected_td_idx = range(len(idx)) - out.append(idx.unbind(0)) + if not encountered_tensor: + # num_single -= idx.ndim - 1 + encountered_tensor = True + else: + num_single += 1 + selected_td_idx = idx + # out.append(idx.unbind(0)) else: raise TypeError(f"Invalid index type: {type(idx)}.") else: @@ -537,13 +549,11 @@ def _split_index(self, index): ( ftdim.Dim, slice, - list, - range, ), ): out.append(idx) - elif isinstance(idx, (np.ndarray, torch.Tensor)): - if idx.dtype in (np.dtype("bool"), torch.bool): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: cursor_incr = idx.ndim if cursor < self.stack_dim: num_squash += cursor_incr - 1 @@ -568,7 +578,11 @@ def _split_index(self, index): # smth[torch.tensor(1)].ndim = smth.ndim-1 # smth[torch.tensor([1])].ndim = smth.ndim # smth[torch.tensor([[1]])].ndim = smth.ndim+1 - num_single -= idx.ndim - 1 + if not encountered_tensor: + num_single -= idx.ndim - 1 + encountered_tensor = True + else: + num_single += 1 out.append(idx) else: raise TypeError(f"Invalid index type: {type(idx)}.") @@ -593,20 +607,45 @@ def _split_index(self, index): elif is_nd_tensor: def isindexable(idx): - if isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype in (torch.bool, np.dtype("bool")): + if isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: return False return True if isinstance(idx, (tuple, list, range)): return True return False - out = tuple( - tuple(idx if not isindexable(idx) else idx[i] for idx in out) - for i in selected_td_idx - ) + def outer_list(tensor_index, tuple_index): + """Converts a tensor and a tuple to a nested list where each leaf is a (int, index) tuple where the index only points to one element.""" + if isinstance(tensor_index, torch.Tensor): + list_index = tensor_index.tolist() + else: + list_index = tensor_index + list_result = [] + + def index_tuple_index(i, convert=False): + for idx in tuple_index: + if isindexable(idx): + if convert: + yield int(idx[i]) + else: + yield idx[i] + else: + yield idx + + for i, idx in enumerate(list_index): + if isinstance(idx, int): + list_result.append( + (idx, tuple(index_tuple_index(i, 
convert=True))) + ) + elif isinstance(idx, list): + list_result.append(outer_list(idx, tuple(index_tuple_index(i)))) + else: + raise NotImplementedError + return list_result + return { - "index_dict": dict(enumerate(out)), + "index_dict": outer_list(selected_td_idx, out), "num_single": num_single, "isinteger": isinteger, "has_bool": has_bool, @@ -646,8 +685,19 @@ def _set_at_str(self, key, value, index, *, validated): if is_nd_tensor: unbind_dim = self.stack_dim - num_single + num_none - num_squash value_unbind = value.unbind(unbind_dim) - for idx, _value in zip(converted_idx.values(), value_unbind): - self._set_at_str(key, _value, idx, validated=validated) + + def set_at_str(converted_idx): + for i, item in enumerate(converted_idx): + if isinstance(item, list): + set_at_str(item) + else: + _value = value_unbind[i] + stack_idx, idx = item + self.tensordicts[stack_idx]._set_at_str( + key, _value, idx, validated=validated + ) + + set_at_str(converted_idx) return self elif not has_bool: unbind_dim = self.stack_dim - num_single + num_none - num_squash @@ -718,7 +768,7 @@ def _legacy_unsqueeze(self, dim: int) -> T: else: dim = dim - 1 stack_dim = self.stack_dim - return LazyStackedTensorDict( + return type(self)( *(tensordict.unsqueeze(dim) for tensordict in self.tensordicts), stack_dim=stack_dim, ) @@ -756,7 +806,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T: else: dim = dim - 1 stack_dim = self.stack_dim - return LazyStackedTensorDict( + return type(self)( *(tensordict.squeeze(dim) for tensordict in self.tensordicts), stack_dim=stack_dim, ) @@ -810,7 +860,9 @@ def _get_str( # then we consider this default as non-stackable and return prematurly return default try: - out = self.lazy_stack(tensors, self.stack_dim) + out = self.lazy_stack( + tensors, self.stack_dim, stack_dim_name=self._td_dim_name + ) if _is_tensor_collection(out.__class__): if isinstance(out, LazyStackedTensorDict): # then it's a LazyStackedTD @@ -822,8 +874,6 @@ def _get_str( self._batch_size + out.batch_size[(len(self._batch_size) + incr) :] ) - if self._td_dim_name is not None: - out._td_dim_name = self._td_dim_name elif is_tensorclass(out): # then it's a tensorclass out._tensordict.hook_out = self.hook_out @@ -834,8 +884,6 @@ def _get_str( self._batch_size + out._tensordict.batch_size[(len(self._batch_size) + incr) :] ) - if self._td_dim_name is not None: - out._tensordict._td_dim_name = self._td_dim_name else: raise RuntimeError elif self.hook_out is not None: @@ -880,25 +928,30 @@ def lazy_stack( cls, items: Sequence[TensorDictBase], dim: int = 0, + *, device: DeviceType | None = None, out: T | None = None, + stack_dim_name: str | None = None, ) -> T: """Stacks tensordicts in a LazyStackedTensorDict.""" if not items: raise RuntimeError("items cannot be empty") - from .tensorclass import NonTensorData - if all(isinstance(item, torch.Tensor) for item in items): return torch.stack(items, dim=dim, out=out) if all( is_tensorclass(item) and type(item) == type(items[0]) # noqa: E721 for item in items ): - if all(isinstance(tensordict, NonTensorData) for tensordict in items): + if all(is_non_tensor(tensordict) for tensordict in items): + from .tensorclass import NonTensorData + return NonTensorData._stack_non_tensor(items, dim=dim) lazy_stack = cls.lazy_stack( - [item._tensordict for item in items], dim=dim, out=out + [item._tensordict for item in items], + dim=dim, + out=out, + stack_dim_name=stack_dim_name, ) # we take the first non_tensordict by convention return type(items[0])._from_tensordict( @@ -923,7 +976,9 
@@ def lazy_stack( # The first case is handled within _check_keys which fails if keys # don't match exactly. # The second requires a check over the tensor shapes. - return LazyStackedTensorDict(*items, stack_dim=dim) + return LazyStackedTensorDict( + *items, stack_dim=dim, stack_dim_name=stack_dim_name + ) else: batch_size = list(batch_size) batch_size.insert(dim, len(items)) @@ -1236,7 +1291,7 @@ def contiguous(self) -> T: return out def empty(self, recurse=False) -> T: - return LazyStackedTensorDict( + return type(self)( *[td.empty(recurse=recurse) for td in self.tensordicts], stack_dim=self.stack_dim, ) @@ -1245,17 +1300,17 @@ def _clone(self, recurse: bool = True) -> T: if recurse: # This could be optimized using copy but we must be careful with # metadata (_is_shared etc) - result = LazyStackedTensorDict( + result = type(self)( *[td._clone() for td in self.tensordicts], stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) else: - result = LazyStackedTensorDict( + result = type(self)( *[td._clone(recurse=False) for td in self.tensordicts], stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) - if self._td_dim_name is not None: - result._td_dim_name = self._td_dim_name return result def pin_memory(self) -> T: @@ -1274,7 +1329,7 @@ def to(self, *args, **kwargs) -> T: if device is not None and dtype is None and device == self.device: return result - return LazyStackedTensorDict( + return type(self)( *[td.to(*args, **kwargs) for td in self.tensordicts], stack_dim=self.stack_dim, hook_out=self.hook_out, @@ -1403,16 +1458,15 @@ def _apply_nest( if filter_empty and all(r is None for r in results): return if not inplace: - out = LazyStackedTensorDict( + out = type(self)( *results, stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) else: out = self if names is not None: out.names = names - else: - out._td_dim_name = self._td_dim_name return out def _select( @@ -1429,7 +1483,7 @@ def _select( ] if inplace: return self - result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) + result = type(self)(*tensordicts, stack_dim=self.stack_dim) return result def _exclude( @@ -1442,7 +1496,7 @@ def _exclude( if inplace: self.tensordicts = tensordicts return self - result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) + result = type(self)(*tensordicts, stack_dim=self.stack_dim) return result def __setitem__(self, index: IndexType, value: T) -> T: @@ -1460,10 +1514,12 @@ def __setitem__(self, index: IndexType, value: T) -> T: ) return - if any(isinstance(sub_index, (list, range)) for sub_index in index): + if any( + isinstance(sub_index, (list, range, np.ndarray)) for sub_index in index + ): index = tuple( - torch.tensor(sub_index, device=self.device) - if isinstance(sub_index, (list, range)) + torch.as_tensor(sub_index, device=self.device) + if isinstance(sub_index, (list, range, np.ndarray)) else sub_index for sub_index in index ) @@ -1471,9 +1527,9 @@ def __setitem__(self, index: IndexType, value: T) -> T: if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index): index = convert_ellipsis_to_idx(index, self.batch_size) elif isinstance(index, (list, range)): - index = torch.tensor(index, device=self.device) + index = torch.as_tensor(index, device=self.device) - if isinstance(value, (TensorDictBase, dict)): + if is_tensor_collection(value) or isinstance(value, dict): indexed_bs = _getitem_batch_size(self.batch_size, index) if isinstance(value, dict): value = TensorDict( @@ -1500,13 +1556,26 @@ def __setitem__(self, index: 
IndexType, value: T) -> T: if isinteger: # this will break if the index along the stack dim is [0] or :1 or smth for i, _idx in converted_idx.items(): - self.tensordicts[i][_idx] = value + if _idx == (): + self.tensordicts[i].update(value, inplace=True) + else: + self.tensordicts[i][_idx] = value return self if is_nd_tensor: - raise RuntimeError( - "Indexing along stack dim with a non-boolean tensor is not supported yet. " - "Use SubTensorDict instead." - ) + unbind_dim = self.stack_dim - num_single + num_none - num_squash + + # converted_idx is a nested list with (int, index) items + def assign(converted_idx, value=value): + value = value.unbind(unbind_dim) + for i, item in enumerate(converted_idx): + if isinstance(item, list): + assign(item) + else: + stack_item, idx = item + self.tensordicts[stack_item][idx] = value[i] + + assign(converted_idx) + return self if not has_bool: unbind_dim = self.stack_dim - num_single + num_none - num_squash value_unbind = value.unbind(unbind_dim) @@ -1514,7 +1583,10 @@ def __setitem__(self, index: IndexType, value: T) -> T: converted_idx.items(), value_unbind, ): - self.tensordicts[i][_idx] = _value + if _idx == (): + self.tensordicts[i].update(_value, inplace=True) + else: + self.tensordicts[i][_idx] = _value else: # we must split, not unbind mask_unbind = split_index["individual_masks"] @@ -1582,11 +1654,22 @@ def __getitem__(self, index: IndexType) -> T: return torch.cat(result, cat_dim) elif is_nd_tensor: new_stack_dim = self.stack_dim - num_single + num_none - out = LazyStackedTensorDict.lazy_stack( - [self[idx] for idx in converted_idx.values()], new_stack_dim - ) - out._td_dim_name = self._td_dim_name - return out + + def recompose(converted_idx, stack_dim=new_stack_dim): + stack = [] + for item in converted_idx: + if isinstance(item, list): + stack.append(recompose(item, stack_dim=stack_dim)) + else: + stack_elt, idx = item + stack.append(self.tensordicts[stack_elt][idx]) + # TODO: this produces multiple dims with the same name + result = LazyStackedTensorDict.lazy_stack( + stack, stack_dim, stack_dim_name=self._td_dim_name + ) + return result + + return recompose(converted_idx) else: if isinteger: for ( @@ -1603,9 +1686,13 @@ def __getitem__(self, index: IndexType) -> T: result = [] new_stack_dim = self.stack_dim - num_single + num_none - num_squash for i, _idx in converted_idx.items(): - result.append(self.tensordicts[i][_idx]) - result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim) - result._td_dim_name = self._td_dim_name + if _idx == (): + result.append(self.tensordicts[i]) + else: + result.append(self.tensordicts[i][_idx]) + result = LazyStackedTensorDict.lazy_stack( + result, new_stack_dim, stack_dim_name=self._td_dim_name + ) return result def __eq__(self, other): @@ -2329,9 +2416,9 @@ def _transpose(self, dim0, dim1): # example: shape = [5, 4, 3, 2, 1], stack_dim=1, dim0=1, dim1=4 # resulting shape: [5, 1, 3, 2, 4] if dim1 == dim0 + 1: - result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim1) + result = type(self)(*self.tensordicts, stack_dim=dim1) else: - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0, dim1 - 1) for td in self.tensordicts), stack_dim=dim1, ) @@ -2339,20 +2426,20 @@ def _transpose(self, dim0, dim1): # example: shape = [5, 4, 3, 2, 1], stack_dim=3, dim0=1, dim1=3 # resulting shape: [5, 2, 3, 4, 1] if dim0 + 1 == dim1: - result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim0) + result = type(self)(*self.tensordicts, stack_dim=dim0) else: - result = 
LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0 + 1, dim1) for td in self.tensordicts), stack_dim=dim0, ) else: dim0 = dim0 if dim0 < self.stack_dim else dim0 - 1 dim1 = dim1 if dim1 < self.stack_dim else dim1 - 1 - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0, dim1) for td in self.tensordicts), stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = self._td_dim_name return result def _permute( @@ -2370,9 +2457,10 @@ def _permute( d if d < self.stack_dim else d - 1 for d in dims_list if d != self.stack_dim ] result = LazyStackedTensorDict.lazy_stack( - [td.permute(dims_list) for td in self.tensordicts], stack_dim + [td.permute(dims_list) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = self._td_dim_name return result def _squeeze(self, dim=None): @@ -2395,9 +2483,10 @@ def _squeeze(self, dim=None): else: stack_dim = self.stack_dim - 1 result = LazyStackedTensorDict.lazy_stack( - [td.squeeze(dim) for td in self.tensordicts], stack_dim + [td.squeeze(dim) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = result._td_dim_name else: result = self for dim in range(self.batch_dims - 1, -1, -1): @@ -2420,9 +2509,10 @@ def _unsqueeze(self, dim): else: stack_dim = self.stack_dim + 1 result = LazyStackedTensorDict.lazy_stack( - [td.unsqueeze(dim) for td in self.tensordicts], stack_dim + [td.unsqueeze(dim) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = result._td_dim_name return result lock_ = TensorDictBase.lock_ diff --git a/tensordict/_td.py b/tensordict/_td.py index c3b5d85dd..021c9151e 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -71,6 +71,7 @@ DeviceType, expand_as_right, IndexType, + is_non_tensor, is_tensorclass, KeyedJaggedTensor, lock_blocked, @@ -308,9 +309,8 @@ def is_empty(self): if _is_tensor_collection(type(item)): if not item.is_empty(): return False - from tensordict.tensorclass import NonTensorData - if isinstance(item, NonTensorData): + if is_non_tensor(item): return False else: return False @@ -689,7 +689,11 @@ def make_result(): any_set = False for key, item in self.items(): - if not call_on_nested and _is_tensor_collection(item.__class__): + if ( + not call_on_nested + and _is_tensor_collection(item.__class__) + # and not is_non_tensor(item) + ): if default is not NO_DEFAULT: _others = [_other._get_str(key, default=None) for _other in others] _others = [ @@ -1557,7 +1561,9 @@ def _set_at_str(self, key, value, idx, *, validated): tensor_in = _sub_index(tensor_in, idx) tensor_in.copy_(value) else: - _set_item(tensor_in, idx, value, validated=validated) + tensor_out = _set_item(tensor_in, idx, value, validated=validated) + if tensor_in is not tensor_out: + self._set_str(key, tensor_out, validated=True, inplace=False) return self @@ -2385,10 +2391,11 @@ def _set_at_str(self, key, value, idx, *, validated): ) tensor_in = _sub_index(tensor_in, idx) tensor_in.copy_(value) + tensor_out = tensor_in else: - _set_item(tensor_in, idx, value, validated=validated) + tensor_out = _set_item(tensor_in, idx, value, validated=validated) # make sure that the value is updated - self._source._set_at_str(key, tensor_in, self.idx, validated=validated) + self._source._set_at_str(key, tensor_out, self.idx, validated=validated) return self def _set_at_tuple(self, key, value, idx, *, validated): @@ -2449,15 +2456,17 @@ def get( def _get_non_tensor(self, 
key: NestedKey, default=NO_DEFAULT): out = super()._get_non_tensor(key, default=default) - from tensordict.tensorclass import NonTensorData - if isinstance(out, _SubTensorDict) and isinstance(out._source, NonTensorData): - return out._source.data + if isinstance(out, _SubTensorDict) and is_non_tensor(out._source): + return out._source return out def _get_str(self, key, default): if key in self.keys() and _is_tensor_collection(self.entry_class(key)): - return _SubTensorDict(self._source._get_str(key, NO_DEFAULT), self.idx) + data = self._source._get_str(key, NO_DEFAULT) + if is_non_tensor(data): + return data[self.idx] + return _SubTensorDict(data, self.idx) return self._source._get_at_str(key, self.idx, default=default) def _get_tuple(self, key, default): @@ -3031,7 +3040,10 @@ def __contains__(self, key: NestedKey) -> bool: if isinstance(key, str): if key in self._keys(): if self.leaves_only: - return not _is_tensor_collection(self.tensordict.entry_class(key)) + # TODO: make this faster for LazyStacked without compromising regular + return not _is_tensor_collection( + type(self.tensordict._get_str(key)) + ) return True return False else: @@ -3039,25 +3051,30 @@ def __contains__(self, key: NestedKey) -> bool: if len(key) == 1: return key[0] in self._keys() elif self.include_nested: - if key[0] in self._keys(): - entry_type = self.tensordict.entry_class(key[0]) - if entry_type in (Tensor, _MemmapTensor): + item_root = self.tensordict._get_str(key[0], default=None) + if item_root is not None: + entry_type = type(item_root) + if issubclass(entry_type, (Tensor, _MemmapTensor)): return False - if entry_type is KeyedJaggedTensor: + elif entry_type is KeyedJaggedTensor: if len(key) > 2: return False - return key[1] in self.tensordict.get(key[0]).keys() + return key[1] in item_root.keys() + # TODO: make this faster for LazyStacked without compromising regular _is_tensordict = _is_tensor_collection(entry_type) if _is_tensordict: # # this will call _unravel_key_to_tuple many times # return key[1:] in self.tensordict._get_str(key[0], NO_DEFAULT).keys(include_nested=self.include_nested) # this won't call _unravel_key_to_tuple but requires to get the default which can be suboptimal - leaf_td = self.tensordict._get_tuple(key[:-1], None) - if leaf_td is None or ( - not _is_tensor_collection(leaf_td.__class__) - and not isinstance(leaf_td, KeyedJaggedTensor) - ): - return False + if len(key) >= 3: + leaf_td = item_root._get_tuple(key[1:-1], None) + if leaf_td is None or ( + not _is_tensor_collection(leaf_td.__class__) + and not isinstance(leaf_td, KeyedJaggedTensor) + ): + return False + else: + leaf_td = item_root return key[-1] in leaf_td.keys() return False # this is reached whenever there is more than one key but include_nested is False diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index b960ea843..fbabc77e4 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -19,6 +19,7 @@ _check_keys, _ErrorInteceptor, DeviceType, + is_non_tensor, lazy_legacy, set_lazy_legacy, ) @@ -163,6 +164,32 @@ def _ones_like(td: T, **kwargs: Any) -> T: return td_clone +@implements_for_td(torch.rand_like) +def _rand_like(td: T, **kwargs: Any) -> T: + td_clone = td._fast_apply(lambda x: torch.rand_like(x)) + if "device" in kwargs: + td_clone = td_clone.to(kwargs.pop("device")) + if len(kwargs): + raise RuntimeError( + f"keyword arguments {list(kwargs.keys())} are not " + f"supported with full_like with TensorDict" + ) + return td_clone + + +@implements_for_td(torch.randn_like) +def 
_randn_like(td: T, **kwargs: Any) -> T: + td_clone = td._fast_apply(lambda x: torch.randn_like(x)) + if "device" in kwargs: + td_clone = td_clone.to(kwargs.pop("device")) + if len(kwargs): + raise RuntimeError( + f"keyword arguments {list(kwargs.keys())} are not " + f"supported with full_like with TensorDict" + ) + return td_clone + + @implements_for_td(torch.empty_like) def _empty_like(td: T, *args, **kwargs) -> T: try: @@ -355,9 +382,9 @@ def _stack( if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - from tensordict.tensorclass import NonTensorData + if all(is_non_tensor(td) for td in list_of_tensordicts): + from tensordict.tensorclass import NonTensorData - if all(isinstance(td, NonTensorData) for td in list_of_tensordicts): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) batch_size = list_of_tensordicts[0].batch_size diff --git a/tensordict/base.py b/tensordict/base.py index 9203b6ba3..682e80afd 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -60,6 +60,7 @@ IndexType, infer_size_impl, int_generator, + is_non_tensor, KeyedJaggedTensor, lazy_legacy, lock_blocked, @@ -242,10 +243,11 @@ def __getitem__(self, index: IndexType) -> T: idx_unravel = _unravel_key_to_tuple(index) if idx_unravel: result = self._get_tuple(idx_unravel, NO_DEFAULT) - from .tensorclass import NonTensorData - - if isinstance(result, NonTensorData): - return result.data + if is_non_tensor(result): + result_data = getattr(result, "data", NO_DEFAULT) + if result_data is NO_DEFAULT: + return result.tolist() + return result_data return result if (istuple and not index) or (not istuple and index is Ellipsis): @@ -1655,6 +1657,14 @@ def cuda(self, device: int = None) -> T: return self.to(torch.device("cuda")) return self.to(f"cuda:{device}") + @property + def is_cuda(self): + return self.device is not None and self.device.type == "cuda" + + @property + def is_cpu(self): + return self.device is not None and self.device.type == "cpu" + # Serialization functionality def state_dict( self, @@ -2118,9 +2128,7 @@ def memmap_like( return result else: return TensorDictFuture(futures, result) - input = self.apply( - lambda x: torch.empty((), device=x.device, dtype=x.dtype).expand(x.shape) - ) + input = self.apply(lambda x: torch.empty_like(x)) return input._memmap_( prefix=prefix, copy_existing=copy_existing, @@ -2308,18 +2316,19 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from tensordict.tensorclass import NonTensorData - if isinstance(value, NonTensorData): - return value.data + if is_non_tensor(value): + data = getattr(value, "data", None) + if data is None: + return value.tolist() + return data return value def filter_non_tensor_data(self) -> T: """Filters out all non-tensor-data.""" - from tensordict.tensorclass import NonTensorData def _filter(x): - if not isinstance(x, NonTensorData): + if not is_non_tensor(x): if is_tensor_collection(x): return x.filter_non_tensor_data() return x @@ -5376,10 +5385,8 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict.tensorclass import NonTensorData - if issubclass(cls, KeyedJaggedTensor): return False if _is_tensor_collection(cls): - return issubclass(cls, NonTensorData) + return cls.__dict__.get("_non_tensor", False) return issubclass(cls, torch.Tensor) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 
81a124c81..e5714bf4e 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -1117,8 +1117,6 @@ def compute_should_use_set_data(tensor, tensor_applied): param.data = param_applied out_param = param else: - assert isinstance(param, nn.Parameter) - assert param.is_leaf out_param = nn.Parameter(param_applied, param.requires_grad) self._parameters[key] = out_param @@ -1129,10 +1127,8 @@ def compute_should_use_set_data(tensor, tensor_applied): param.grad, grad_applied ) if should_use_set_data: - assert out_param.grad is not None out_param.grad.data = grad_applied else: - assert param.grad.is_leaf out_param.grad = grad_applied.requires_grad_( param.grad.requires_grad ) @@ -1150,8 +1146,6 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.data = buffer_applied out_buffer = buffer else: - assert isinstance(buffer, Buffer) - assert buffer.is_leaf out_buffer = Buffer(buffer_applied, buffer.requires_grad) self._buffers[key] = out_buffer @@ -1162,10 +1156,8 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.grad, grad_applied ) if should_use_set_data: - assert out_buffer.grad is not None out_buffer.grad.data = grad_applied else: - assert buffer.grad.is_leaf out_buffer.grad = grad_applied.requires_grad_( buffer.grad.requires_grad ) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 3e1d698bd..105be6538 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -14,7 +14,7 @@ import torch from torch import nn -AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "True")) +AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "False")) DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True")) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 2c0397422..0af45cea4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -15,7 +15,7 @@ import re import sys import warnings -from copy import copy +from copy import copy, deepcopy from dataclasses import dataclass from pathlib import Path from textwrap import indent @@ -24,10 +24,11 @@ import tensordict as tensordict_lib import torch +from tensordict import LazyStackedTensorDict from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class +from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -36,6 +37,7 @@ _LOCK_ERROR, DeviceType, IndexType, + is_non_tensor, is_tensorclass, NestedKey, ) @@ -56,6 +58,9 @@ torch.full_like, torch.zeros_like, torch.ones_like, + torch.rand_like, + torch.empty_like, + torch.randn_like, torch.clone, torch.squeeze, torch.unsqueeze, @@ -216,6 +221,7 @@ def __torch_function__( cls.to_tensordict = _to_tensordict cls.device = property(_device, _device_setter) cls.batch_size = property(_batch_size, _batch_size_setter) + cls.names = property(_names, _names_setter) cls.__doc__ = f"{cls.__name__}{inspect.signature(cls)}" @@ -483,7 +489,7 @@ def wrapper(self, item: str) -> Any: return wrapper -SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts") +SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable: @@ -722,6 +728,9 @@ def 
_set(self, key: NestedKey, value: Any, inplace: bool = False): __dict__ = self.__dict__ if __dict__["_tensordict"].is_locked: raise RuntimeError(_LOCK_ERROR) + if key in ("batch_size", "names", "device"): + # handled by setattr + return expected_keys = self.__dataclass_fields__ if key not in expected_keys: raise AttributeError( @@ -838,6 +847,26 @@ def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417 self._tensordict._batch_size_setter(new_size) +def _names(self) -> torch.Size: + """Retrieves the dim names for the tensor class. + + Returns: + names (list of str) + + """ + return self._tensordict.names + + +def _names_setter(self, names: str) -> None: # noqa: D417 + """Set the value of ``tensorclass.names``. + + Args: + names (sequence of str) + + """ + self._tensordict.names = names + + def _state_dict( self, destination=None, prefix="", keep_vars=False, flatten=False ) -> dict[str, Any]: @@ -1256,10 +1285,21 @@ class NonTensorData: # to patch tensordict with additional checks that will encur unwanted overhead # and all the overhead falls back on this class. data: Any + _non_tensor: bool = True + + @classmethod + def from_tensor(cls, value: torch.Tensor, batch_size, device=None, names=None): + """A util to create a NonTensorData containing a tensor.""" + out = cls(data=None, batch_size=batch_size, device=device, names=names) + out._non_tensordict["data"] = value + return out def __post_init__(self): - if isinstance(self.data, NonTensorData): - self.data = self.data.data + if is_non_tensor(self.data): + data = getattr(self.data, "data", None) + if data is None: + data = self.data.tolist() + self.data = data old_eq = self.__class__.__eq__ if old_eq is _eq: @@ -1314,8 +1354,25 @@ def __or__(self, other): self.__class__.__or__ = __or__ + def update( + self, + input_dict_or_td: dict[str, CompatibleType] | T, + clone: bool = False, + inplace: bool = False, + *, + keys_to_update: Sequence[NestedKey] | None = None, + ) -> T: + if isinstance(input_dict_or_td, NonTensorData): + data = input_dict_or_td.data + if clone: + data = deepcopy(data) + self.data = data + elif not input_dict_or_td.is_empty(): + raise RuntimeError(f"Unexpected type {type(input_dict_or_td)}") + return self + def empty(self, recurse=False): - return NonTensorData( + return type(self)( data=self.data, batch_size=self.batch_size, names=self.names if self._has_names() else None, @@ -1340,19 +1397,19 @@ def _check_equal(a, b): iseq = False return iseq - if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]): + if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all( + _check_equal(data.data, first.data) for data in list_of_non_tensor[1:] + ): batch_size = list(first.batch_size) batch_size.insert(dim, len(list_of_non_tensor)) - return NonTensorData( + return type(cls)( data=first.data, batch_size=batch_size, names=first.names if first._has_names() else None, device=first.device, ) - from tensordict._lazy import LazyStackedTensorDict - - return LazyStackedTensorDict(*list_of_non_tensor, stack_dim=dim) + return NonTensorStack(*list_of_non_tensor, stack_dim=dim) @classmethod def __torch_function__( @@ -1369,7 +1426,14 @@ def __torch_function__( ): return NotImplemented - escape_conversion = func in (torch.stack,) + escape_conversion = func in ( + torch.stack, + torch.ones_like, + torch.zeros_like, + torch.empty_like, + torch.randn_like, + torch.rand_like, + ) if kwargs is None: kwargs = {} @@ -1406,3 +1470,100 @@ def _fast_apply(self, *args, **kwargs): return 
_wrap_method(self, "_fast_apply", self._tensordict._fast_apply)( *args, **kwargs ) + + def tolist(self): + """Converts the data in a list if the batch-size is non-empty. + + If the batch-size is empty, returns the data. + + """ + if not self.batch_size: + return self.data + return [ntd.tolist() for ntd in self.unbind(0)] + + def copy_(self, src: NonTensorData | NonTensorStack, non_blocking: bool = False): + if isinstance(src, NonTensorStack): + raise RuntimeError( + "Cannot update a NonTensorData with a NonTensorStack object." + ) + if not isinstance(src, NonTensorData): + raise RuntimeError( + "NonTensorData.copy_ requires the source to be a NonTensorData object." + ) + self._non_tensordict["data"] = src.data + + def clone(self, recurse: bool = True): + if recurse: + return type(self)( + data=deepcopy(self.data), + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + return type(self)( + data=self.data, + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + + +class NonTensorStack(LazyStackedTensorDict): + """A thin wrapper around LazyStackedTensorDict to make stack on non-tensor data easily recognizable. + + A ``NonTensorStack`` is returned whenever :func:`~torch.stack` is called on + a list of :class:`~tensordict.NonTensorData` or ``NonTensorStack``. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> print(data) + NonTensorStack( + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, ..., + batch_size=torch.Size([3, 2]), + device=None) + + To obtain the values stored in a ``NonTensorStack``, call :class:`~.tolist`. + + """ + + _non_tensor: bool = True + + def tolist(self): + """Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> data.tolist() + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]] + + """ + iterator = self.tensordicts if self.stack_dim == 0 else self.unbind(0) + return [td.tolist() for td in iterator] + + @classmethod + def from_nontensordata(cls, non_tensor: NonTensorData): + data = non_tensor.data + prev = NonTensorData(data, batch_size=[], device=non_tensor.device) + for dim in reversed(non_tensor.shape): + prev = cls(*[prev.clone(False) for _ in range(dim)], stack_dim=0) + return prev + + def __repr__(self): + selfrepr = str(self.tolist()) + if len(selfrepr) > 50: + selfrepr = f"{selfrepr[:50]}..." 
+ selfrepr = indent(selfrepr, prefix=4 * " ") + batch_size = indent(f"batch_size={self.batch_size}", prefix=4 * " ") + device = indent(f"device={self.device}", prefix=4 * " ") + return f"NonTensorStack(\n{selfrepr}," f"\n{batch_size}," f"\n{device})" + + def to_dict(self) -> dict[str, Any]: + return self.tolist() diff --git a/tensordict/utils.py b/tensordict/utils.py index e72163b1b..c64c30a7a 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -49,6 +49,7 @@ unravel_key_list, unravel_keys, ) + from torch import Tensor from torch._C import _disabled_torch_function_impl from torch.nn.parameter import ( @@ -59,7 +60,6 @@ ) from torch.utils.data._utils.worker import _generate_state - if TYPE_CHECKING: from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.tensordict import TensorDictBase @@ -659,6 +659,21 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> elif isinstance(tensor, KeyedJaggedTensor): tensor = setitem_keyedjaggedtensor(tensor, index, value) return tensor + from tensordict.tensorclass import NonTensorData, NonTensorStack + + if is_non_tensor(tensor): + if ( + isinstance(value, NonTensorData) + and isinstance(tensor, NonTensorData) + and tensor.data == value.data + ): + return tensor + elif isinstance(tensor, NonTensorData): + tensor = NonTensorStack.from_nontensordata(tensor) + if tensor.stack_dim != 0: + tensor = NonTensorStack(*tensor.unbind(0), stack_dim=0) + tensor[index] = value + return tensor else: tensor[index] = value return tensor @@ -1506,9 +1521,7 @@ def _expand_to_match_shape( def _set_max_batch_size(source: T, batch_dims=None): """Updates a tensordict with its maximium batch size.""" - from tensordict import NonTensorData - - tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)] + tensor_data = [val for val in source.values() if not is_non_tensor(val)] for val in tensor_data: from tensordict.base import _is_tensor_collection @@ -1587,7 +1600,7 @@ def wrapper(*args, **kwargs): def _broadcast_tensors(index): # tensors and range need to be broadcast tensors = { - i: tensor if isinstance(tensor, Tensor) else torch.tensor(tensor) + i: torch.as_tensor(tensor) for i, tensor in enumerate(index) if isinstance(tensor, (range, list, np.ndarray, Tensor)) } @@ -2156,3 +2169,8 @@ def __call__(self, mod: torch.nn.Module, args, kwargs): return else: raise RuntimeError("did not find pre-hook") + + +def is_non_tensor(data): + """Checks if an item is a non-tensor.""" + return type(data).__dict__.get("_non_tensor", False) diff --git a/test/test_nn.py b/test/test_nn.py index 37d5975b0..016237077 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -415,6 +415,7 @@ def test_functional_before(self): tensordict_module = TensorDictModule( module=net, in_keys=["in"], out_keys=["out"] ) + make_functional(tensordict_module, return_params=False) td = TensorDict({"in": torch.randn(3, 3)}, [3]) tensordict_module(td, params=TensorDict({"module": params}, [])) @@ -580,6 +581,7 @@ def test_functional_with_buffer(self): tdmodule = TensorDictModule(module=net, in_keys=["in"], out_keys=["out"]) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) + make_functional(tdmodule, return_params=False) tdmodule(td, params=TensorDict({"module": params}, [])) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 819556a6f..66e2875c8 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py 
@@ -2774,6 +2774,7 @@ def test_index_tensor_nd_names(self, td_name, device, npy): index = index.numpy() td_idx = td[:, index] assert tensor_example[:, index].shape == td_idx.shape + # TODO: this multiple dims with identical names should not be allowed assert td_idx.names == [names[0], names[1], names[1], *names[2:]] td_idx = td[0, index] assert tensor_example[0, index].shape == td_idx.shape @@ -6102,8 +6103,16 @@ def test_lazy_indexing(self, pos1, pos2, pos3): index = (pos1, pos2, pos3) result = outer[index] ref_tensor = torch.zeros(outer.shape) - assert result.batch_size == ref_tensor[index].shape, index - assert result.batch_size == outer_dense[index].shape, index + assert result.batch_size == ref_tensor[index].shape, ( + result.batch_size, + ref_tensor[index].shape, + index, + ) + assert result.batch_size == outer_dense[index].shape, ( + result.batch_size, + outer_dense[index].shape, + index, + ) @pytest.mark.parametrize("stack_dim", [0, 1, 2]) @pytest.mark.parametrize("mask_dim", [0, 1, 2]) @@ -7788,6 +7797,29 @@ def test_map_with_out(self, mmap, chunksize, tmpdir): input.map(self.selectfn, num_workers=2, chunksize=chunksize, out=out) assert (out["a"] == torch.arange(10)).all(), (chunksize, mmap) + @classmethod + def nontensor_check(cls, td): + td["check"] = td["non_tensor"] == ( + "a string!" if (td["tensor"] % 2) == 0 else "another string!" + ) + return td + + def test_non_tensor(self): + # with NonTensorStack + td = TensorDict( + {"tensor": torch.arange(10), "non_tensor": "a string!"}, batch_size=[10] + ) + td[1::2] = TensorDict({"non_tensor": "another string!"}, [5]) + td = td.map(self.nontensor_check, chunksize=0) + assert td["check"].all() + # with NonTensorData + td = TensorDict( + {"tensor": torch.zeros(10, dtype=torch.int), "non_tensor": "a string!"}, + batch_size=[10], + ) + td = td.map(self.nontensor_check, chunksize=0) + assert td["check"].all() + # class TestNonTensorData: class TestNonTensorData: @@ -7855,6 +7887,43 @@ def test_stack(self, non_tensor_data): LazyStackedTensorDict, ) + def test_assign_non_tensor(self): + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + assert data["b"] == "a string!" + assert data.get("b").tolist() == [["a string!"] * 10] + data[0, 1] = TensorDict({"a": 0, "b": "another string!"}, []) + assert data.get("b").tolist() == [ + ["a string!"] + ["another string!"] + ["a string!"] * 8 + ] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0, 5:] = TensorDict({"a": torch.zeros(5), "b": "another string!"}, [5]) + assert data.get("b").tolist() == [["a string!"] * 5 + ["another string!"] * 5] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0, 0::2] = TensorDict( + {"a": torch.zeros(5, dtype=torch.long), "b": "another string!"}, [5] + ) + assert data.get("b").tolist() == [["another string!", "a string!"] * 5] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0] = TensorDict( + {"a": torch.zeros(10, dtype=torch.long), "b": "another string!"}, [10] + ) + assert data.get("b").tolist() == [["another string!"] * 10] + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args()
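
The sketches below are illustrative only and assume a tensordict build that includes the changes in this patch; they mirror the NonTensorStack docstring and the test_assign_non_tensor cases added above. Stacking NonTensorData objects with distinct payloads now produces a NonTensorStack, and non-tensor entries of a TensorDict can be written through plain indexing:

    # Minimal sketch of NonTensorData / NonTensorStack behaviour, taken from the
    # docstring and tests in this patch.
    import torch
    from tensordict import NonTensorData, TensorDict

    # Stacking NonTensorData with distinct payloads yields a NonTensorStack.
    stack = torch.stack([
        torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)])
        for j in range(3)
    ])
    print(stack.tolist())  # [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]]

    # Non-tensor entries can be written through regular indexing.
    data = TensorDict({}, [1, 10])
    data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, [])
    assert data["b"] == "a string!"
    data[0, 1] = TensorDict({"a": 0, "b": "another string!"}, [])
    assert data.get("b").tolist() == [["a string!"] + ["another string!"] + ["a string!"] * 8]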
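A sketch of the new stack_dim_name argument of LazyStackedTensorDict and lazy_stack; the name "time" is an arbitrary example and the printed names are indicative:

    import torch
    from tensordict import LazyStackedTensorDict, TensorDict

    tds = [TensorDict({"a": torch.zeros(3)}, [3]) for _ in range(2)]

    # The stack-dimension name can be passed at construction time or to lazy_stack
    # (keyword-only there); it is carried through clone, permute, squeeze/unsqueeze,
    # transpose and get.
    stacked = LazyStackedTensorDict(*tds, stack_dim=0, stack_dim_name="time")
    print(stacked.names)  # e.g. ['time', None]
    assert stacked.clone().names == stacked.names

    # A stack_dim larger than the shared batch size now raises a RuntimeError.
    try:
        LazyStackedTensorDict(*tds, stack_dim=2)
    except RuntimeError as err:
        print(err)  # Stack dim 2 is too big for batch size torch.Size([3]).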
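A sketch of the is_non_tensor helper and of the torch.rand_like / torch.randn_like dispatch added above:

    import torch
    from tensordict import NonTensorData, TensorDict
    from tensordict.utils import is_non_tensor

    # is_non_tensor relies on the _non_tensor class flag set on both
    # NonTensorData and NonTensorStack.
    assert is_non_tensor(NonTensorData(data="a string!", batch_size=[]))
    stack = torch.stack([NonTensorData(data=i, batch_size=[]) for i in range(2)])
    assert is_non_tensor(stack)           # NonTensorStack carries the flag too
    assert not is_non_tensor(torch.zeros(()))
    assert not is_non_tensor(TensorDict({}, []))

    # torch.rand_like / torch.randn_like are now dispatched over TensorDict leaves.
    td = TensorDict({"a": torch.zeros(3, 4), "nested": {"b": torch.zeros(3)}}, [3])
    noise = torch.randn_like(td)
    assert noise["a"].shape == torch.Size([3, 4])
    assert not td.is_cuda                 # new is_cuda / is_cpu properties; False when device is None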