Commit 2aea3dd
Author: Vincent Moens
Parent: eb4a56e

[Refactor] Better compile checks

ghstack-source-id: c6a8d45
Pull Request resolved: #1139
File tree: 6 files changed (+46, -49 lines)

6 files changed

+46
-49
lines changed

tensordict/_td.py (8 additions, 8 deletions)

@@ -98,9 +98,9 @@
 
 _has_funcdim = False
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 try:
     from torch.nn.parameter import Buffer
@@ -251,7 +251,7 @@ def __init__(
 
         self._tensordict = _StringOnlyDict()
 
-        # if names and is_dynamo_compiling():
+        # if names and is_compiling():
         #     graph_break()
         has_device = device is not None
         sub_non_blocking = False
@@ -284,7 +284,7 @@ def __init__(
         )
         self._batch_size = self._parse_batch_size(source, batch_size)
         # TODO: this breaks when stacking tensorclasses with dynamo
-        if not is_dynamo_compiling():
+        if not is_compiling():
             self.names = names
 
         for key, value in source.items():
@@ -313,7 +313,7 @@ def _new_unsafe(
         nested: bool = True,
         **kwargs: dict[str, Any] | None,
     ) -> TensorDict:
-        if is_dynamo_compiling():
+        if is_compiling():
             return TensorDict(
                 source,
                 batch_size=batch_size,
@@ -473,7 +473,7 @@ def _to_module(
         is_dynamo: bool | None = None,
     ):
         if is_dynamo is None:
-            is_dynamo = is_dynamo_compiling()
+            is_dynamo = is_compiling()
         if is_dynamo:
             _check_inbuild()
 
@@ -2264,7 +2264,7 @@ def _parse_batch_size(
     ) -> torch.Size:
         ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."
 
-        if is_dynamo_compiling():
+        if is_compiling():
             if isinstance(batch_size, torch.Size):
                 return batch_size
             elif isinstance(batch_size, tuple):
@@ -2316,7 +2316,7 @@ def names(self):
 
     @names.setter
     def names(self, value):
-        if is_dynamo_compiling():
+        if is_compiling():
             if value is not None:
                 graph_break()
             else:
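
For context on the pattern this commit standardizes: newer torch releases expose the check as torch.compiler.is_compiling, while torch 2.0 only ships torch._dynamo.is_compiling, hence the try/except shim at the top of each module. Below is a minimal, self-contained sketch of how the imported guard is typically used; the validate_names helper is hypothetical and only illustrates the "skip eager-only bookkeeping while tracing" idiom seen in __init__ above.

```python
# Compatibility shim, as in the diff: prefer the public torch.compiler API and
# fall back to the private torch 2.0 location.
try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


def validate_names(names):
    # Hypothetical helper: run eager-only validation, but skip it while Dynamo
    # is tracing to avoid graph breaks (mirrors `if not is_compiling(): ...`).
    if not is_compiling():
        if names is not None and len(set(names)) != len(names):
            raise ValueError("duplicate dimension names")
    return names
```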

tensordict/_torch_func.py (5 additions, 5 deletions)

@@ -39,9 +39,9 @@
 )
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
 LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
@@ -301,7 +301,7 @@ def _cat(
     out = {}
     for key in keys:
         items = [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts]
-        if not is_dynamo_compiling():
+        if not is_compiling():
            with _ErrorInteceptor(
                key, "Attempted to concatenate tensors on different devices at key"
            ):
@@ -335,7 +335,7 @@ def _cat(
            _ErrorInteceptor(
                key, "Attempted to concatenate tensors on different devices at key"
            )
-            if not is_dynamo_compiling()
+            if not is_compiling()
            else contextlib.nullcontext()
        ):
            if isinstance(out, TensorDict):
@@ -592,7 +592,7 @@ def stack_fn(key, values, is_not_init, is_tensor):
            _ErrorInteceptor(
                key, "Attempted to stack tensors on different devices at key"
            )
-            if not is_dynamo_compiling()
+            if not is_compiling()
            else contextlib.nullcontext()
        ):
            return _stack(values, dim, maybe_dense_stack=maybe_dense_stack)
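
The _cat and _stack hunks keep the same conditional-context-manager idiom: the error-interception wrapper is only entered in eager mode and is swapped for contextlib.nullcontext() under compile. A rough, self-contained sketch of that idiom follows; ErrorInterceptor here is a stand-in for tensordict's internal _ErrorInteceptor, not the real class.

```python
import contextlib

import torch

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


class ErrorInterceptor(contextlib.AbstractContextManager):
    """Stand-in for _ErrorInteceptor: annotate exceptions with the offending key."""

    def __init__(self, key, msg):
        self.key, self.msg = key, msg

    def __exit__(self, exc_type, exc, tb):
        if exc is not None:
            raise RuntimeError(f"{self.msg} {self.key!r}") from exc
        return False


def cat_at_key(key, tensors, dim=0):
    # Eager mode: wrap the call to get a keyed error message.
    # Compiled mode: use a no-op context, since the wrapper breaks the graph.
    with (
        ErrorInterceptor(key, "Attempted to concatenate tensors on different devices at key")
        if not is_compiling()
        else contextlib.nullcontext()
    ):
        return torch.cat(tensors, dim)
```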

tensordict/base.py (14 additions, 17 deletions)

@@ -103,9 +103,9 @@
 from torch.utils._pytree import tree_map
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 try:
     from torch import _foreach_copy_
@@ -5247,7 +5247,7 @@ def _view_and_pad(tensor):
            if v.device != storage.device:
                v = v.to(storage.device, non_blocking=non_blocking)
            stride = v.stride()
-            if is_dynamo_compiling():
+            if is_compiling():
                if not v.is_contiguous():
                    v = v.clone(memory_format=torch.contiguous_format)
            elif (stride and stride[-1] != 1) or v.storage_offset():
@@ -6963,7 +6963,7 @@ def _values_list(
            is_leaf=is_leaf,
            collapse=collapse,
        )
-        if is_dynamo_compiling():
+        if is_compiling():
            key_to_index = {key: i for i, key in enumerate(keys)}
            return [vals[key_to_index[key]] for key in sorting_keys]
        else:
@@ -6994,7 +6994,7 @@ def _items_list(
            return list(keys), list(vals)
        if default is None:
            # TODO: check that lists are identical
-            if is_dynamo_compiling():
+            if is_compiling():
                key_to_index = {key: i for i, key in enumerate(keys)}
                new_vals = [vals[key_to_index[key]] for key in sorting_keys]
                if len(new_vals) < len(vals):
@@ -7015,12 +7015,9 @@ def _items_list(
            ]  # intersection does not keep the sorting
        else:
            new_keys = list(set(sorting_keys).union(keys))
-            if is_dynamo_compiling():
-                ...
-            else:
-                source = dict(zip(keys, vals))
-                vals = [source.get(key, default) for key in new_keys]
-            return new_keys, vals
+            source = dict(zip(keys, vals))
+            vals = [source.get(key, default) for key in new_keys]
+            return new_keys, vals
 
    def _grad(self):
        # We can't cache this because zero_grad can be called outside (eg from optimizer) and we want the tensors
@@ -11931,7 +11928,7 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
            result.lock_()
            return result
        else:
-            if not is_dynamo_compiling():
+            if not is_compiling():
                key_list = list(self.keys())
            else:
                key_list = [k for k in self.keys()]  # noqa
@@ -12196,10 +12193,10 @@ def lock_(self) -> T:
        """
        if self.is_locked:
            return self
-        is_compiling = is_dynamo_compiling()
-        if is_compiling:
+        is_comp = is_compiling()
+        if is_comp:
            _lock_warn()
-        self._propagate_lock(is_compiling=is_compiling)
+        self._propagate_lock(is_compiling=is_comp)
        return self
 
    @erase_cache
@@ -12611,7 +12608,7 @@ def copy_dict(d):
    def _sync_all(self):
        if _has_cuda:
            # TODO: dynamo doesn't like torch.cuda.is_initialized
-            if not is_dynamo_compiling() and torch.cuda.is_initialized():
+            if not is_compiling() and torch.cuda.is_initialized():
                torch.cuda.synchronize()
        elif _has_mps:
            mps = getattr(torch, "mps", None)
@@ -12799,7 +12796,7 @@ def _register_tensor_class(cls):
 
 
 def _is_tensor_collection(datatype: type) -> bool:
-    is_dynamo = is_dynamo_compiling()
+    is_dynamo = is_compiling()
    out = None
    if not is_dynamo:
        out = _TENSOR_COLLECTION_MEMO.get(datatype)
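
Beyond the rename, two details stand out in this file: _items_list drops its compile-specific branch (both modes now build the source dict), and lock_ renames its local flag to is_comp so it no longer shadows the newly imported is_compiling function. The _is_tensor_collection hunk keeps its shape of bypassing a module-level memo while tracing; a rough sketch of that shape follows, with a hypothetical _MEMO dict and marker check standing in for tensordict's internals.

```python
try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling

_MEMO: dict = {}  # hypothetical stand-in for _TENSOR_COLLECTION_MEMO


def is_tensor_collection_sketch(datatype: type) -> bool:
    # Same control flow as _is_tensor_collection in the diff: only read and
    # write the memo in eager mode; while tracing, recompute the answer.
    is_dynamo = is_compiling()
    out = None
    if not is_dynamo:
        out = _MEMO.get(datatype)
    if out is None:
        # Hypothetical predicate; the real check inspects tensordict's own markers.
        out = hasattr(datatype, "_is_tensor_collection")
        if not is_dynamo:
            _MEMO[datatype] = out
    return out
```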

tensordict/nn/common.py (4 additions, 4 deletions)

@@ -27,9 +27,9 @@
 from torch import nn, Tensor
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 try:
     from functorch import FunctionalModule, FunctionalModuleWithBuffers
@@ -1153,7 +1153,7 @@ def __repr__(self) -> str:
        return f"{type(self).__name__}(\n{fields})"
 
    def __getattr__(self, name: str) -> Any:
-        if not is_dynamo_compiling():
+        if not is_compiling():
            __dict__ = self.__dict__
            _parameters = __dict__.get("_parameters")
            if _parameters:
@@ -1230,7 +1230,7 @@ def __init__(self, td_module: TensorDictModuleBase) -> None:
            self.register_forward_hook(self.td_module._forward_hooks[pre_hook])
 
    def __getattr__(self, name: str) -> Any:
-        if not is_dynamo_compiling():
+        if not is_compiling():
            __dict__ = self.__dict__
            _parameters = __dict__.get("_parameters")
            if _parameters:
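
Both __getattr__ hunks guard the same fast path: peeking directly into __dict__ for parameters is only done in eager mode and is skipped under compile so that Dynamo traces the plain fallback instead. A toy sketch of that pattern, not the real class:

```python
try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


class AttrLookup:
    """Toy illustration of the __getattr__ guard in nn/common.py."""

    def __init__(self):
        self.__dict__["_parameters"] = {"weight": 1.0}

    def __getattr__(self, name):
        # Eager-only fast path: look the name up in the private dict directly.
        if not is_compiling():
            _parameters = self.__dict__.get("_parameters")
            if _parameters and name in _parameters:
                return _parameters[name]
        # Fallback (and the path Dynamo traces): regular attribute error.
        raise AttributeError(name)
```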

tensordict/nn/utils.py (5 additions, 5 deletions)

@@ -18,9 +18,9 @@
 from torch.utils._contextlib import _DecoratorContextManager
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 
 _dispatch_tdnn_modules = _ContextManager(
@@ -300,7 +300,7 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
        return super().__call__(wrapper)
 
    def __enter__(self) -> None:
-        if self.mode and is_dynamo_compiling():
+        if self.mode and is_compiling():
            raise RuntimeError("skip_existing is not compatible with TorchDynamo.")
        self.prev = _skip_existing.get_mode()
        if self.mode is not None:
@@ -338,7 +338,7 @@ def __call__(self, func: Callable):
 
        @functools.wraps(func)
        def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
-            if skip_existing() and is_dynamo_compiling():
+            if skip_existing() and is_compiling():
                raise RuntimeError(
                    "skip_existing is not compatible with torch.compile."
                )
@@ -351,7 +351,7 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
                and not any(key in out_keys for key in in_keys)
            ):
                return tensordict
-            if is_dynamo_compiling():
+            if is_compiling():
                return func(_self, tensordict, *args, **kwargs)
            self.prev = _skip_existing.get_mode()
            try:
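
These hunks make the skip_existing machinery refuse to cooperate with compilation: the context manager raises on __enter__, and the decorator either raises or falls straight through to the wrapped function. A condensed sketch of the decorator side, with a module-level flag standing in for tensordict's _skip_existing state:

```python
import functools

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling

_SKIP_EXISTING_MODE = False  # stand-in for tensordict's _skip_existing state


def skip_existing() -> bool:
    return _SKIP_EXISTING_MODE


def set_skip_existing_sketch(func):
    # Condensed version of the wrapper in nn/utils.py: incompatible with
    # torch.compile when the mode is active, and a plain pass-through while
    # tracing so no mode bookkeeping enters the graph.
    @functools.wraps(func)
    def wrapper(_self, tensordict, *args, **kwargs):
        if skip_existing() and is_compiling():
            raise RuntimeError("skip_existing is not compatible with torch.compile.")
        if is_compiling():
            return func(_self, tensordict, *args, **kwargs)
        # Eager path: the real code saves and restores the mode around the call.
        return func(_self, tensordict, *args, **kwargs)

    return wrapper
```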

tensordict/tensorclass.py (10 additions, 10 deletions)

@@ -64,9 +64,9 @@
 from torch.utils._pytree import tree_map
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 
 def _identity(cls):
@@ -890,7 +890,7 @@ def wrapper(
        if lock is None:
            lock = frozen
 
-        if not is_dynamo_compiling():
+        if not is_compiling():
            # zip not supported by dynamo
            for value, key in zip(args, self.__dataclass_fields__):
                if key in kwargs:
@@ -904,7 +904,7 @@ def wrapper(
 
        if batch_size is None:
            batch_size = torch.Size([])
-        if not is_dynamo_compiling():
+        if not is_compiling():
            for key, field in type(self).__dataclass_fields__.items():
                if field.default_factory is not dataclasses.MISSING:
                    default = field.default_factory()
@@ -1072,7 +1072,7 @@ def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True):  # noqa:
    # tensordict = tensordict.copy()
    tensor_keys = tensordict.keys()
    # TODO: compile doesn't like set() over an arbitrary object
-    if is_dynamo_compiling():
+    if is_compiling():
        tensor_keys = {k for k in tensor_keys}  # noqa: C416
        exp_keys = {k for k in cls.__expected_keys__}  # noqa: C416
        if non_tensordict is not None:
@@ -1112,7 +1112,7 @@ def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True):  # noqa:
        for key in to_add:
            non_tensordict[key] = None
 
-    if not is_dynamo_compiling():
+    if not is_compiling():
        # bypass initialisation. this means we don't incur any overhead creating an
        # empty tensordict and writing values to it. we can skip this because we already
        # have a tensordict to use as the underlying tensordict
@@ -1313,7 +1313,7 @@ def wrapper(self, key: str, value: Any) -> None:  # noqa: D417
            value (any): the value to set for the attribute
 
        """
-        if not is_dynamo_compiling():
+        if not is_compiling():
            __dict__ = self.__dict__
            if (
                "_tensordict" not in __dict__
@@ -1348,7 +1348,7 @@ def deliver_result(self, result, kwargs):
        if result is None:
            return
        if isinstance(result, TensorDictBase) and kwargs.get("out") is not result:
-            if not is_dynamo_compiling():
+            if not is_compiling():
                non_tensordict = super(type(self), self).__getattribute__(
                    "_non_tensordict"
                )
@@ -1362,7 +1362,7 @@ def deliver_result(self, result, kwargs):
            return result
 
    def wrapped_func(self, *args, **kwargs):
-        if not is_dynamo_compiling():
+        if not is_compiling():
            td = super(type(self), self).__getattribute__("_tensordict")
        else:
            td = self._tensordict
@@ -1409,7 +1409,7 @@ def wrapped_func(*args, **kwargs):
            return type(self)._from_tensordict(res, dict(self._non_tensordict))
        return res
 
-    if not is_dynamo_compiling():
+    if not is_compiling():
        wrapped_func = functools.wraps(func)(wrapped_func)
 
    return wrapped_func
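
One last detail worth calling out: the generated wrapped_func only gets functools.wraps applied in eager mode; while compiling, the metadata copy is skipped (presumably because it is not needed at trace time). A minimal sketch of that tail:

```python
import functools

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


def make_wrapped(func):
    # Mirrors the tail of tensorclass.py's wrapper factory: forward the call,
    # and only copy __name__/__doc__ etc. onto the wrapper in eager mode.
    def wrapped_func(*args, **kwargs):
        return func(*args, **kwargs)

    if not is_compiling():
        wrapped_func = functools.wraps(func)(wrapped_func)

    return wrapped_func
```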
