Skip to content

Commit 2ed9a70

Browse files
committed
amend
1 parent 86d6406 commit 2ed9a70

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

tensordict/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1975,7 +1975,7 @@ def memmap_like(
19751975
else:
19761976
return TensorDictFuture(futures, result)
19771977
input = self.apply(
1978-
lambda x: torch.empty((), device=x.device, dtype=x.dtype).expand(x.shape)
1978+
lambda x: torch.empty_like(x)
19791979
)
19801980
return input._memmap_(
19811981
prefix=prefix,

tensordict/tensorclass.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@
5656
torch.full_like,
5757
torch.zeros_like,
5858
torch.ones_like,
59+
torch.empty_like,
60+
torch.randn_like,
61+
torch.rand_like,
5962
torch.clone,
6063
torch.squeeze,
6164
torch.unsqueeze,
@@ -1246,6 +1249,13 @@ class NonTensorData:
12461249
# and all the overhead falls back on this class.
12471250
data: Any
12481251

1252+
@classmethod
1253+
def from_tensor(cls, value: torch.Tensor, batch_size, device=None, names=None):
1254+
"""A util to create a NonTensorData containing a tensor."""
1255+
out = cls(data=None, batch_size=batch_size, device=device, names=names)
1256+
out._non_tensordict["data"] = value
1257+
return out
1258+
12491259
def __post_init__(self):
12501260
if isinstance(self.data, NonTensorData):
12511261
self.data = self.data.data
@@ -1304,7 +1314,7 @@ def __or__(self, other):
13041314
self.__class__.__or__ = __or__
13051315

13061316
def empty(self, recurse=False):
1307-
return NonTensorData(
1317+
return type(self)(
13081318
data=self.data,
13091319
batch_size=self.batch_size,
13101320
names=self.names if self._has_names() else None,
@@ -1332,7 +1342,7 @@ def _check_equal(a, b):
13321342
if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]):
13331343
batch_size = list(first.batch_size)
13341344
batch_size.insert(dim, len(list_of_non_tensor))
1335-
return NonTensorData(
1345+
return type(self)(
13361346
data=first.data,
13371347
batch_size=batch_size,
13381348
names=first.names if first._has_names() else None,
@@ -1358,7 +1368,7 @@ def __torch_function__(
13581368
):
13591369
return NotImplemented
13601370

1361-
escape_conversion = func in (torch.stack,)
1371+
escape_conversion = func in (torch.stack, torch.ones_like, torch.zeros_like, torch.empty_like, torch.randn_like, torch.rand_like)
13621372

13631373
if kwargs is None:
13641374
kwargs = {}

0 commit comments

Comments
 (0)