diff --git a/tensordict/_pytree.py b/tensordict/_pytree.py
index af8d6d0c2..0475d7865 100644
--- a/tensordict/_pytree.py
+++ b/tensordict/_pytree.py
@@ -12,13 +12,9 @@
 from tensordict.persistent import PersistentTensorDict
 from tensordict.utils import _shape, implement_for
 
-try:
-    from torch.utils._pytree import Context, MappingKey, register_pytree_node
-except ImportError:
-    from torch.utils._pytree import (
-        _register_pytree_node as register_pytree_node,
-        Context,
-    )
+from torch.utils._pytree import Context, MappingKey, register_pytree_node
+from torch.utils._cxx_pytree import register_pytree_node as register_pytree_node_cxx
+
 
 PYTREE_REGISTERED_TDS = (
     _SubTensorDict,
@@ -83,7 +79,7 @@ def _str_to_tensordictdict(str_spec: str) -> Tuple[List[str], str]:
 
 
 def _tensordict_flatten(d: TensorDict) -> Tuple[List[Any], Context]:
-    items = tuple(d.items())
+    items = tuple(d._tensordict.items())
     if items:
         keys, values = zip(*items)
         keys = list(keys)
@@ -186,11 +182,17 @@ def _register_td_node(cls):
 
 
 @implement_for("torch", "2.3")
 def _register_td_node(cls):  # noqa: F811
-    register_pytree_node(
+    # register_pytree_node(
+    #     cls,
+    #     _tensordict_flatten,
+    #     _tensordict_unflatten,
+    #     flatten_with_keys_fn=_td_flatten_with_keys,
+    # )
+    register_pytree_node_cxx(
         cls,
         _tensordict_flatten,
         _tensordict_unflatten,
-        flatten_with_keys_fn=_td_flatten_with_keys,
+        # flatten_with_keys_fn=_td_flatten_with_keys,
     )
 
@@ -205,11 +207,17 @@ def _register_lazy_td_node(cls):
 
 
 @implement_for("torch", "2.3")
 def _register_lazy_td_node(cls):  # noqa: F811
-    register_pytree_node(
+    # register_pytree_node(
+    #     cls,
+    #     _lazy_tensordict_flatten,
+    #     _lazy_tensordict_unflatten,
+    #     flatten_with_keys_fn=_lazy_td_flatten_with_keys,
+    # )
+    register_pytree_node_cxx(
         cls,
         _lazy_tensordict_flatten,
         _lazy_tensordict_unflatten,
-        flatten_with_keys_fn=_lazy_td_flatten_with_keys,
+        # flatten_with_keys_fn=_lazy_td_flatten_with_keys,
     )