Motivation
The native nn.Sequential supports input in the form of an OrderedDict[str, Module] to set the names of the underlying modules, rather than just using ordinal numbers (the default when the modules are passed as positional arguments). This isn't currently possible with the analogous TensorDictSequential class, but it would be nice if we could do the same.
Example
Native torch behavior
>>> from collections import OrderedDict
>>> from torch import nn
>>> W = nn.Linear(10, 1)
>>> nn.Sequential(W)
Sequential(
  (0): Linear(in_features=10, out_features=1, bias=True)
)
>>> nn.Sequential(OrderedDict(fc0=W))
Sequential(
  (fc0): Linear(in_features=10, out_features=1, bias=True)
)
Current tensordict behavior
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> W_tdm = TensorDictModule(W, ['x'], ['y'])
>>> TensorDictSequential(W_tdm)
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=Linear(in_features=10, out_features=1, bias=True),
          device=cpu,
          in_keys=['x'],
          out_keys=['y'])
    ),
    device=cpu,
    in_keys=['x'],
    out_keys=['y'])
>>> TensorDictSequential(OrderedDict(fc0=W_tdm))
TensorDictSequential(
    module=ModuleList(
      (0): WrapModule()
    ),
    device=cpu,
    in_keys=[],
    out_keys=[])
Desired behavior
>>> TensorDictSequential(OrderedDict(fc0=W_tdm))
TensorDictSequential(
    module=ModuleList(
      (fc0): TensorDictModule(
          module=Linear(in_features=10, out_features=1, bias=True),
          device=cpu,
          in_keys=['x'],
          out_keys=['y'])
    ),
    device=cpu,
    in_keys=['x'],
    out_keys=['y'])
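The naming rule nn.Sequential applies can be sketched in plain Python: a single OrderedDict argument supplies explicit child names, and anything else is named by ordinal index. The helper resolve_child_names below is hypothetical and not part of torch or tensordict; it only illustrates the dispatch a TensorDictSequential constructor might perform before registering each child (e.g. via add_module). Plain strings stand in for modules so the sketch runs without torch installed.

```python
from collections import OrderedDict
from typing import Any, List, Tuple


def resolve_child_names(*args: Any) -> List[Tuple[str, Any]]:
    """Hypothetical helper mirroring nn.Sequential's child-naming rule.

    A single OrderedDict argument provides explicit names (in insertion
    order); otherwise positional arguments are named "0", "1", ...
    """
    if len(args) == 1 and isinstance(args[0], OrderedDict):
        # Explicit names: keep the user's keys and their insertion order.
        return list(args[0].items())
    # Default: ordinal names, as nn.Sequential does for positional modules.
    return [(str(idx), module) for idx, module in enumerate(args)]


# Placeholder strings stand in for real modules in this sketch.
print(resolve_child_names("linear_a", "linear_b"))
# [('0', 'linear_a'), ('1', 'linear_b')]
print(resolve_child_names(OrderedDict(fc0="linear_a", fc1="linear_b")))
# [('fc0', 'linear_a'), ('fc1', 'linear_b')]
```

Like nn.Sequential, the sketch checks for OrderedDict specifically rather than any mapping, so a plain dict passed alone would be treated as a single positional child.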