### Describe the bug
I'm currently using hydra in combination with `TensorDictModule` and am running into `ValueError`s when building modules, because the underlying config uses container subclasses (OmegaConf's `DictConfig`/`ListConfig`) rather than the base classes (`dict`/`list`).
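For context, a minimal check (assuming only omegaconf is installed) shows that OmegaConf's containers fail `isinstance` checks against the built-in types but pass the corresponding `collections.abc` ones:

```python
from collections.abc import Mapping, MutableSequence

from omegaconf import OmegaConf

cfg = OmegaConf.create({"in_keys": {"x": "input"}, "out_keys": ["h"]})

# ListConfig does not subclass list, but it implements MutableSequence
print(isinstance(cfg.out_keys, list))             # False
print(isinstance(cfg.out_keys, MutableSequence))  # True

# DictConfig does not subclass dict, but it implements Mapping
print(isinstance(cfg.in_keys, dict))              # False
print(isinstance(cfg.in_keys, Mapping))           # True
```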
### Example
My hydra config looks like this:
```yaml
# configs/config.yaml
linear:
  module:
    _target_: torch.nn.Linear
    in_features: 64
    out_features: 64
  in_keys: [x]
  out_keys: [h]
```
which can be parsed in Python like so:
```python
# app.py
import hydra
from omegaconf import DictConfig

from tensordict.nn import TensorDictModule


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg: DictConfig):
    linear = hydra.utils.instantiate(cfg.linear.module)
    td_linear = TensorDictModule(linear, cfg.linear.in_keys, cfg.linear.out_keys)


if __name__ == "__main__":
    main()
```
but running this script raises a `ValueError`:
```
$ python app.py
ValueError: out_keys must be of type list, str or tuples of str.
```
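One easy workaround in the meantime (a minimal sketch; `OmegaConf.to_container` converts config nodes back into plain Python containers) is to convert the key lists before constructing the module:

```python
from omegaconf import OmegaConf

# Convert the ListConfig nodes to plain lists so the isinstance checks pass.
td_linear = TensorDictModule(
    linear,
    OmegaConf.to_container(cfg.linear.in_keys),
    OmegaConf.to_container(cfg.linear.out_keys),
)
```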
### Proposed solution
Replace the built-in container type checks in `tensordict/nn/common.py` with their `collections.abc` counterparts (which is the Python recommendation for accepting duck-typed containers). For example, change
```python
# tensordict/nn/common.py#L928
if isinstance(in_keys, dict):
    # write the kwargs and create a list instead
    _in_keys = []
    self._kwargs = []
    for key, value in in_keys.items():
        self._kwargs.append(value)
        _in_keys.append(key)
    in_keys = _in_keys
else:
    if isinstance(in_keys, (str, tuple)):
        in_keys = [in_keys]
    elif not isinstance(in_keys, list):
        raise ValueError(self._IN_KEY_ERR)
    self._kwargs = None
```
to
```python
import collections.abc  # would need to be imported at the top of common.py

if isinstance(in_keys, collections.abc.Mapping):
    # write the kwargs and create a list instead
    _in_keys = []
    self._kwargs = []
    for key, value in in_keys.items():
        self._kwargs.append(value)
        _in_keys.append(key)
    in_keys = _in_keys
else:
    if isinstance(in_keys, (str, tuple)):
        in_keys = [in_keys]
    elif not isinstance(in_keys, collections.abc.MutableSequence):  # possibly even the more general `Iterable`
        raise ValueError(self._IN_KEY_ERR)
    self._kwargs = None
```
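With this change, both OmegaConf containers would take the intended branch. A quick sanity check (a minimal sketch, assuming omegaconf is installed; the variable names are illustrative):

```python
import collections.abc

from omegaconf import OmegaConf

kwargs_cfg = OmegaConf.create({"input": "x"})  # a DictConfig
keys_cfg = OmegaConf.create(["x", "y"])        # a ListConfig

# The DictConfig would be routed into the kwargs branch...
assert isinstance(kwargs_cfg, collections.abc.Mapping)
# ...and the ListConfig would pass the sequence check instead of raising.
assert isinstance(keys_cfg, collections.abc.MutableSequence)
```

Even the broader `Iterable` should be safe here, since plain strings are already caught by the `(str, tuple)` branch before the sequence check runs.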
I don't think this is critical, as it's not even a "bug" per se and is easy to get around, but it would be a nice QOL change. Thanks as always!