56
56
torch .full_like ,
57
57
torch .zeros_like ,
58
58
torch .ones_like ,
59
+ torch .empty_like ,
60
+ torch .randn_like ,
61
+ torch .rand_like ,
59
62
torch .clone ,
60
63
torch .squeeze ,
61
64
torch .unsqueeze ,
@@ -1246,6 +1249,13 @@ class NonTensorData:
1246
1249
# and all the overhead falls back on this class.
1247
1250
data : Any
1248
1251
1252
+ @classmethod
1253
+ def from_tensor (cls , value : torch .Tensor , batch_size , device = None , names = None ):
1254
+ """A util to create a NonTensorData containing a tensor."""
1255
+ out = cls (data = None , batch_size = batch_size , device = device , names = names )
1256
+ out ._non_tensordict ["data" ] = value
1257
+ return out
1258
+
1249
1259
def __post_init__ (self ):
1250
1260
if isinstance (self .data , NonTensorData ):
1251
1261
self .data = self .data .data
@@ -1304,7 +1314,7 @@ def __or__(self, other):
1304
1314
self .__class__ .__or__ = __or__
1305
1315
1306
1316
def empty (self , recurse = False ):
1307
- return NonTensorData (
1317
+ return type ( self ) (
1308
1318
data = self .data ,
1309
1319
batch_size = self .batch_size ,
1310
1320
names = self .names if self ._has_names () else None ,
@@ -1332,7 +1342,7 @@ def _check_equal(a, b):
1332
1342
if all (_check_equal (data .data , first .data ) for data in list_of_non_tensor [1 :]):
1333
1343
batch_size = list (first .batch_size )
1334
1344
batch_size .insert (dim , len (list_of_non_tensor ))
1335
- return NonTensorData (
1345
+ return type ( self ) (
1336
1346
data = first .data ,
1337
1347
batch_size = batch_size ,
1338
1348
names = first .names if first ._has_names () else None ,
@@ -1358,7 +1368,7 @@ def __torch_function__(
1358
1368
):
1359
1369
return NotImplemented
1360
1370
1361
- escape_conversion = func in (torch .stack ,)
1371
+ escape_conversion = func in (torch .stack , torch . ones_like , torch . zeros_like , torch . empty_like , torch . randn_like , torch . rand_like )
1362
1372
1363
1373
if kwargs is None :
1364
1374
kwargs = {}
0 commit comments