From fabb9664e7dba5e6787a3f1ff87b160820ec17b1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 15 May 2025 07:55:11 +0100 Subject: [PATCH] [Refactor] Pass all keys at reset (prototype) --- test/test_env.py | 14 ++++++++------ torchrl/envs/common.py | 5 +++-- torchrl/envs/custom/chess.py | 9 ++++++--- torchrl/envs/transforms/transforms.py | 6 ++---- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 48509aff4bf..c0142c630a3 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4083,7 +4083,7 @@ def test_reset_white_to_move(self, stateful, include_pgn, include_fen): stateful=stateful, include_pgn=include_pgn, include_fen=include_fen ) fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1" - td = env.reset(TensorDict({"fen": fen})) + td = env.reset(TensorDict({"fen_reset": fen})) if include_fen: assert td["fen"] == fen assert env.board.fen() == fen @@ -4097,7 +4097,7 @@ def test_reset_black_to_move(self, stateful, include_pgn, include_fen): stateful=stateful, include_pgn=include_pgn, include_fen=include_fen ) fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1" - td = env.reset(TensorDict({"fen": fen})) + td = env.reset(TensorDict({"fen_reset": fen})) assert td["fen"] == fen assert env.board.fen() == fen assert td["turn"] == env.lib.BLACK @@ -4111,7 +4111,7 @@ def test_reset_done_error(self, stateful, include_pgn, include_fen): ) fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1" with pytest.raises(ValueError) as e_info: - env.reset(TensorDict({"fen": fen})) + env.reset(TensorDict({"fen_reset": fen})) assert "Cannot reset to a fen that is a gameover state" in str(e_info) @@ -4181,7 +4181,7 @@ def test_reward( if reset_without_fen: td = TensorDict({"fen": fen}) else: - td = env.reset(TensorDict({"fen": fen})) + td = env.reset(TensorDict({"fen_reset": fen})) assert td["turn"] == expected_turn td["action"] = env._san_moves.index(move) @@ -4230,16 +4230,18 @@ def test_env_reset_with_hash(self, stateful, include_san): ] for fen, 
num_legal_moves in cases: # Load the state by fen. - td = env.reset(TensorDict({"fen": fen})) + td = env.reset(TensorDict({"fen_reset": fen})) assert td["fen"] == fen assert td["action_mask"].sum() == num_legal_moves + # Reset to initial state just to make sure that the next reset # actually changes the state. assert env.reset()["action_mask"].sum() == 20 + # Load the state by fen hash and make sure it gives the same output # as before. td_check = env.reset(td.select("fen_hash")) - assert (td_check == td).all() + assert assert_allclose_td(td_check, td, intersection=True) @pytest.mark.parametrize("include_fen", [False, True]) @pytest.mark.parametrize("include_pgn", [False, True]) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b9b120dc9f5..5d4f74497e7 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2782,9 +2782,10 @@ def reset( # Therefore, maybe_reset tells reset to temporarily hide the non-reset keys. # To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used. 
tensordict_reset = self._reset( - tensordict.select(*self.reset_keys, strict=False), **kwargs + tensordict.exclude(*self.state_keys), **kwargs ) else: + # no key filtering needed here; pass the tensordict through as-is tensordict_reset = self._reset(tensordict, **kwargs) # We assume that this is done properly # if reset.device != self.device: @@ -3634,7 +3635,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """ any_done = self.any_done(tensordict) if any_done: - tensordict = self.reset(tensordict, select_reset_only=True) + tensordict = self.reset(tensordict) return tensordict def empty_cache(self): diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py index 742519f2fec..b268bed5cc3 100644 --- a/torchrl/envs/custom/chess.py +++ b/torchrl/envs/custom/chess.py @@ -93,6 +93,9 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta): The action space is structured as a categorical distribution over all possible SAN moves, with the legal moves being a subset of this space. The environment uses a mask to ensure only legal moves are selected. + .. note:: You can reset the env at a given state by passing `"fen_reset"` or `"pgn_reset"` to the TensorDict passed + to the reset method. + Examples: >>> import torch >>> from torchrl.envs import ChessEnv @@ -322,7 +325,7 @@ def __init__( self.stateful = stateful # state_spec is loosely defined as such - it's not really an issue that extra keys - # can go missing but it allows us to reset the env using fen passed to the reset + # can go missing, but it allows us to reset the env using fen passed to the reset # method.
self.full_state_spec = self.full_observation_spec.clone() @@ -374,11 +377,11 @@ def _reset(self, tensordict=None): if tensordict is not None: dest = tensordict.empty() if self.include_fen: - fen = tensordict.get("fen", None) + fen = tensordict.get("fen_reset", None) if fen is not None: fen = fen.data elif self.include_pgn: - pgn = tensordict.get("pgn", None) + pgn = tensordict.get("pgn_reset", None) if pgn is not None: pgn = pgn.data else: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 19e2ad7ec7d..01a16934c03 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1173,10 +1173,7 @@ def _set_seed(self, seed: int | None) -> None: def _reset(self, tensordict: TensorDictBase | None = None, **kwargs): if tensordict is not None: # We must avoid modifying the original tensordict so a shallow copy is necessary. - # We just select the input data and reset signal, which is all we need. - tensordict = tensordict.select( - *self.reset_keys, *self.state_spec.keys(True, True), strict=False - ) + tensordict = tensordict.copy() # We always call _reset_env_preprocess, even if tensordict is None - that way one can augment that # method to do any pre-reset operation. # By default, within _reset_env_preprocess we will skip the inv call when tensordict is None. @@ -7225,6 +7222,7 @@ def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: """Resets episode rewards.""" + # tensordict still carries the reset signal keys at this point for in_key, reset_key, out_key in _zip_strict( self.in_keys, self.reset_keys, self.out_keys ):