Skip to content

Commit 92b52a0

Browse files
louisfauryLouis Faury
andauthored
[BugFix] ActionMask is compatible with composite action specs (#3022)
Co-authored-by: Louis Faury <[email protected]>
1 parent 5a13341 commit 92b52a0

File tree

1 file changed

+11
-27
lines changed

1 file changed

+11
-27
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8726,8 +8726,10 @@ def set_container(self, container):
87268726
class ActionMask(Transform):
87278727
"""An adaptive action masker.
87288728
8729-
This transform reads the mask from the input tensordict after the step is executed,
8730-
and adapts the mask of the one-hot / categorical action spec.
8729+
This transform is useful to ensure that randomly generated actions
8730+
respect legal actions, by masking the action specs.
8731+
It reads the mask from the input tensordict after the step is executed,
8732+
and adapts the mask of the finite action spec.
87318733
87328734
.. note:: This transform will fail when used without an environment.
87338735
@@ -8773,8 +8775,6 @@ class ActionMask(Transform):
87738775
>>> base_env = MaskedEnv()
87748776
>>> env = TransformedEnv(base_env, ActionMask())
87758777
>>> r = env.rollout(10)
8776-
>>> env = TransformedEnv(base_env, ActionMask())
8777-
>>> r = env.rollout(10)
87788778
>>> r["action_mask"]
87798779
tensor([[ True, True, True, True],
87808780
[ True, True, False, True],
@@ -8810,45 +8810,29 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
88108810
raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))
88118811

88128812
@property
8813-
def action_spec(self):
8814-
action_spec = self.container.full_action_spec
8815-
keys = self.container.action_keys
8816-
if len(keys) == 1:
8817-
action_spec = action_spec[keys[0]]
8818-
else:
8819-
raise ValueError(
8820-
f"Too many action keys for {self.__class__.__name__}: {keys=}"
8821-
)
8813+
def action_spec(self) -> TensorSpec:
8814+
action_spec = self.container.full_action_spec[self.in_keys[0]]
88228815
if not isinstance(action_spec, self.ACCEPTED_SPECS):
88238816
raise ValueError(
88248817
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
88258818
)
88268819
return action_spec
88278820

88288821
def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
8829-
parent = self.parent
8830-
if parent is None:
8822+
if self.parent is None:
88318823
raise RuntimeError(
88328824
f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment."
88338825
)
8826+
88348827
mask = next_tensordict.get(self.in_keys[1])
8835-
action_spec = self.action_spec
8836-
action_spec.update_mask(mask.to(action_spec.device))
8828+
self.action_spec.update_mask(mask.to(self.action_spec.device))
8829+
88378830
return next_tensordict
88388831

88398832
def _reset(
88408833
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
88418834
) -> TensorDictBase:
8842-
action_spec = self.action_spec
8843-
mask = tensordict.get(self.in_keys[1], None)
8844-
if mask is not None:
8845-
mask = mask.to(action_spec.device)
8846-
action_spec.update_mask(mask)
8847-
8848-
# TODO: Check that this makes sense
8849-
with _set_missing_tolerance(self, True):
8850-
tensordict_reset = self._call(tensordict_reset)
8851-
return tensordict_reset
8835+
return self._call(tensordict_reset)
88528836

88538837

88548838
class VecGymEnvTransform(Transform):

0 commit comments

Comments
 (0)