@@ -8726,8 +8726,10 @@ def set_container(self, container):
8726
8726
class ActionMask (Transform ):
8727
8727
"""An adaptive action masker.
8728
8728
8729
- This transform reads the mask from the input tensordict after the step is executed,
8730
- and adapts the mask of the one-hot / categorical action spec.
8729
+ This transform is useful to ensure that randomly generated actions
8730
+ respect legal actions, by masking the action specs.
8731
+ It reads the mask from the input tensordict after the step is executed,
8732
+ and adapts the mask of the finite action spec.
8731
8733
8732
8734
.. note:: This transform will fail when used without an environment.
8733
8735
@@ -8773,8 +8775,6 @@ class ActionMask(Transform):
8773
8775
>>> base_env = MaskedEnv()
8774
8776
>>> env = TransformedEnv(base_env, ActionMask())
8775
8777
>>> r = env.rollout(10)
8776
- >>> env = TransformedEnv(base_env, ActionMask())
8777
- >>> r = env.rollout(10)
8778
8778
>>> r["action_mask"]
8779
8779
tensor([[ True, True, True, True],
8780
8780
[ True, True, False, True],
@@ -8810,45 +8810,29 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
8810
8810
raise RuntimeError (FORWARD_NOT_IMPLEMENTED .format (type (self )))
8811
8811
8812
8812
@property
8813
- def action_spec (self ):
8814
- action_spec = self .container .full_action_spec
8815
- keys = self .container .action_keys
8816
- if len (keys ) == 1 :
8817
- action_spec = action_spec [keys [0 ]]
8818
- else :
8819
- raise ValueError (
8820
- f"Too many action keys for { self .__class__ .__name__ } : { keys = } "
8821
- )
8813
+ def action_spec (self ) -> TensorSpec :
8814
+ action_spec = self .container .full_action_spec [self .in_keys [0 ]]
8822
8815
if not isinstance (action_spec , self .ACCEPTED_SPECS ):
8823
8816
raise ValueError (
8824
8817
self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
8825
8818
)
8826
8819
return action_spec
8827
8820
8828
8821
def _call (self , next_tensordict : TensorDictBase ) -> TensorDictBase :
8829
- parent = self .parent
8830
- if parent is None :
8822
+ if self .parent is None :
8831
8823
raise RuntimeError (
8832
8824
f"{ type (self )} .parent cannot be None: make sure this transform is executed within an environment."
8833
8825
)
8826
+
8834
8827
mask = next_tensordict .get (self .in_keys [1 ])
8835
- action_spec = self .action_spec
8836
- action_spec . update_mask ( mask . to ( action_spec . device ))
8828
+ self . action_spec . update_mask ( mask . to ( self .action_spec . device ))
8829
+
8837
8830
return next_tensordict
8838
8831
8839
8832
def _reset (
8840
8833
self , tensordict : TensorDictBase , tensordict_reset : TensorDictBase
8841
8834
) -> TensorDictBase :
8842
- action_spec = self .action_spec
8843
- mask = tensordict .get (self .in_keys [1 ], None )
8844
- if mask is not None :
8845
- mask = mask .to (action_spec .device )
8846
- action_spec .update_mask (mask )
8847
-
8848
- # TODO: Check that this makes sense
8849
- with _set_missing_tolerance (self , True ):
8850
- tensordict_reset = self ._call (tensordict_reset )
8851
- return tensordict_reset
8835
+ return self ._call (tensordict_reset )
8852
8836
8853
8837
8854
8838
class VecGymEnvTransform (Transform ):
0 commit comments