diff --git a/carl/envs/__init__.py b/carl/envs/__init__.py index b03f906c..a51850d7 100644 --- a/carl/envs/__init__.py +++ b/carl/envs/__init__.py @@ -64,3 +64,10 @@ def check_spec(spec_name: str) -> bool: from carl.envs.rna import * __all__ += envs.rna.__all__ + +gymnax_spec = iutil.find_spec("gymnax") +found = gymnax_spec is not None +if found: + from carl.envs.gymnax import * + + __all__ += envs.gymnax.__all__ diff --git a/carl/envs/carl_env.py b/carl/envs/carl_env.py index 00105c88..af20a924 100644 --- a/carl/envs/carl_env.py +++ b/carl/envs/carl_env.py @@ -178,7 +178,6 @@ def get_observation_space( context_feature_names=obs_context_feature_names, as_dict=self.obs_context_as_dict, ) - obs_space = spaces.Dict( { "obs": self.base_observation_space, diff --git a/carl/envs/gymnasium/classic_control/carl_pendulum.py b/carl/envs/gymnasium/classic_control/carl_pendulum.py index 10226dd8..7ab92f0a 100644 --- a/carl/envs/gymnasium/classic_control/carl_pendulum.py +++ b/carl/envs/gymnasium/classic_control/carl_pendulum.py @@ -15,9 +15,6 @@ class CARLPendulum(CARLGymnasiumEnv): @staticmethod def get_context_features() -> dict[str, ContextFeature]: return { - "gravity": UniformFloatContextFeature( - "gravity", lower=-np.inf, upper=np.inf, default_value=8.0 - ), "dt": UniformFloatContextFeature( "dt", lower=0, upper=np.inf, default_value=0.05 ), diff --git a/carl/envs/gymnax/__init__.py b/carl/envs/gymnax/__init__.py new file mode 100644 index 00000000..8a574589 --- /dev/null +++ b/carl/envs/gymnax/__init__.py @@ -0,0 +1,15 @@ +from carl.envs.gymnax.classic_control import ( + CARLGymnaxAcrobot, + CARLGymnaxCartPole, + CARLGymnaxMountainCar, + CARLGymnaxMountainCarContinuous, + CARLGymnaxPendulum, +) + +__all__ = [ + "CARLGymnaxAcrobot", + "CARLGymnaxCartPole", + "CARLGymnaxMountainCar", + "CARLGymnaxMountainCarContinuous", + "CARLGymnaxPendulum", +] diff --git a/carl/envs/gymnax/carl_gymnax_env.py b/carl/envs/gymnax/carl_gymnax_env.py new file mode 100644 index 00000000..84db31f3 --- /dev/null +++ b/carl/envs/gymnax/carl_gymnax_env.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Any + +import importlib + +from gymnasium.core import Env + +from carl.context.selection import AbstractSelector +from carl.envs.carl_env import CARLEnv +from carl.envs.gymnax.utils import make_gymnax_env +from carl.utils.types import Contexts + + +class CARLGymnaxEnv(CARLEnv): + env_name: str + + def __init__( + self, + env: Env | None = None, + contexts: Contexts | None = None, + obs_context_features: list[str] + | None = None, # list the context features which should be added to the state + obs_context_as_dict: bool = True, + context_selector: AbstractSelector | type[AbstractSelector] | None = None, + context_selector_kwargs: dict = None, + **kwargs, + ) -> None: + """ + CARL Gymnax Environment. + + Parameters + ---------- + + env : Env | None + Gymnasium environment, the default is None. + If None, instantiate the env with gymnasium's make function and + `self.env_name` which is defined in each child class. + contexts : Contexts | None, optional + Context set, by default None. If it is None, we build the + context set with the default context. + obs_context_features : list[str] | None, optional + Context features which should be included in the observation, by default None. + If they are None, add all context features. + context_selector: AbstractSelector | type[AbstractSelector] | None, optional + The context selector (class), after each reset selects a new context to use. + If None, use a round robin selector. + context_selector_kwargs : dict, optional + Optional keyword arguments for the context selector, by default None. + Only used when `context_selector` is not None. + + Attributes + ---------- + env_name: str + The registered gymnax environment name. + """ + if env is None: + env = make_gymnax_env(env_name=self.env_name) + + super().__init__( + env=env, + contexts=contexts, + obs_context_features=obs_context_features, + obs_context_as_dict=obs_context_as_dict, + context_selector=context_selector, + context_selector_kwargs=context_selector_kwargs, + **kwargs, + ) + + def __getattr__(self, name: str) -> Any: + if name in ["sys", "__getstate__"]: + return getattr(self.env._environment, name) + else: + return getattr(self, name) + + def _update_context(self) -> None: + content = self.env.env_params.__dict__ + content.update(self.context) + # We cannot directly set attributes of env_params because it is a frozen dataclass + + # TODO Make this faster by preloading module? + self.env.env.env_params = getattr( + importlib.import_module(f"gymnax.environments.{self.module}"), "EnvParams" + )(**content) diff --git a/carl/envs/gymnax/classic_control/__init__.py b/carl/envs/gymnax/classic_control/__init__.py new file mode 100644 index 00000000..5450542d --- /dev/null +++ b/carl/envs/gymnax/classic_control/__init__.py @@ -0,0 +1,15 @@ +from carl.envs.gymnax.classic_control.carl_gymnax_acrobot import CARLGymnaxAcrobot +from carl.envs.gymnax.classic_control.carl_gymnax_cartpole import CARLGymnaxCartPole +from carl.envs.gymnax.classic_control.carl_gymnax_mountaincar import ( + CARLGymnaxMountainCar, + CARLGymnaxMountainCarContinuous, +) +from carl.envs.gymnax.classic_control.carl_gymnax_pendulum import CARLGymnaxPendulum + +__all__ = [ + "CARLGymnaxAcrobot", + "CARLGymnaxCartPole", + "CARLGymnaxMountainCar", + "CARLGymnaxMountainCarContinuous", + "CARLGymnaxPendulum", +] diff --git a/carl/envs/gymnax/classic_control/carl_gymnax_acrobot.py b/carl/envs/gymnax/classic_control/carl_gymnax_acrobot.py new file mode 100644 index 00000000..e28b4a7e --- /dev/null +++ b/carl/envs/gymnax/classic_control/carl_gymnax_acrobot.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import jax.numpy as jnp +import numpy as np + +from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv + + +class CARLGymnaxAcrobot(CARLGymnaxEnv): + env_name: str = "Acrobot-v1" + module: str = "classic_control.acrobot" + + @staticmethod + def get_context_features() -> dict[str, ContextFeature]: + return { + "link_length_1": UniformFloatContextFeature( + "link_length_1", lower=0.1, upper=10, default_value=1 + ), # Links can be shrunken and grown by a factor of 10 + "link_length_2": UniformFloatContextFeature( + "link_length_2", lower=0.1, upper=10, default_value=1 + ), # Links can be shrunken and grown by a factor of 10 + "link_mass_1": UniformFloatContextFeature( + "link_mass_1", lower=0.1, upper=10, default_value=1 + ), # Link mass can be shrunken and grown by a factor of 10 + "link_mass_2": UniformFloatContextFeature( + "link_mass_2", lower=0.1, upper=10, default_value=1 + ), # Link mass can be shrunken and grown by a factor of 10 + "link_com_pos_1": UniformFloatContextFeature( + "link_com_pos_1", lower=0, upper=1, default_value=0.5 + ), # Center of mass can move from one end to the other + "link_com_pos_2": UniformFloatContextFeature( + "link_com_pos_2", lower=0, upper=1, default_value=0.5 + ), # Center of mass can move from one end to the other + "link_moi": UniformFloatContextFeature( + "link_moi", lower=0.1, upper=10, default_value=1 + ), # Moments on inertia can be shrunken and grown by a factor of 10 + "max_vel_1": UniformFloatContextFeature( + "max_vel_1", + lower=0.4 * jnp.pi, + upper=40 * jnp.pi, + default_value=4 * jnp.pi, + ), # Velocity can vary by a factor of 10 in either direction + "max_vel_2": UniformFloatContextFeature( + "max_vel_2", + lower=0.9 * np.pi, + upper=90 * np.pi, + default_value=9 * np.pi, + ), # Velocity can vary by a factor of 10 in either direction + "torque_noise_max": UniformFloatContextFeature( + "torque_noise_max", lower=-1, upper=1, default_value=0 + ), # torque is either {-1., 0., 1}. Applying noise of 1. would be quite extreme + } diff --git a/carl/envs/gymnax/classic_control/carl_gymnax_cartpole.py b/carl/envs/gymnax/classic_control/carl_gymnax_cartpole.py new file mode 100644 index 00000000..5e3df165 --- /dev/null +++ b/carl/envs/gymnax/classic_control/carl_gymnax_cartpole.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import importlib + +from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv + + +class CARLGymnaxCartPole(CARLGymnaxEnv): + env_name: str = "CartPole-v1" + module: str = "classic_control.cartpole" + + @staticmethod + def get_context_features() -> dict[str, ContextFeature]: + return { + "gravity": UniformFloatContextFeature( + "gravity", lower=0.01, upper=100, default_value=9.8 + ), + "masscart": UniformFloatContextFeature( + "masscart", lower=0.1, upper=10, default_value=1.0 + ), + "masspole": UniformFloatContextFeature( + "masspole", lower=0.01, upper=1, default_value=0.1 + ), + "length": UniformFloatContextFeature( + "length", lower=0.05, upper=5, default_value=0.5 + ), + "force_mag": UniformFloatContextFeature( + "force_mag", lower=1, upper=100, default_value=10.0 + ), + "tau": UniformFloatContextFeature( + "tau", lower=0.002, upper=0.2, default_value=0.02 + ), + } + + def _update_context(self) -> None: + content = self.env.env_params.__dict__ + content.update(self.context) + content["total_mass"] = content["masspole"] + content["masscart"] + content["polemass_length"] = content["masspole"] * content["length"] + + # TODO Make this faster by preloading module? + self.env.env.env_params = getattr( + importlib.import_module(f"gymnax.environments.{self.module}"), "EnvParams" + )(**content) diff --git a/carl/envs/gymnax/classic_control/carl_gymnax_mountaincar.py b/carl/envs/gymnax/classic_control/carl_gymnax_mountaincar.py new file mode 100644 index 00000000..f0421ac4 --- /dev/null +++ b/carl/envs/gymnax/classic_control/carl_gymnax_mountaincar.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv + + +class CARLGymnaxMountainCar(CARLGymnaxEnv): + env_name: str = "MountainCar-v0" + module: str = "classic_control.mountain_car" + + @staticmethod + def get_context_features() -> dict[str, ContextFeature]: + return { + "max_speed": UniformFloatContextFeature( + "max_speed", lower=1e-3, upper=10, default_value=0.07 + ), + "goal_position": UniformFloatContextFeature( + "goal_position", lower=-2, upper=2, default_value=0.45 + ), + "goal_velocity": UniformFloatContextFeature( + "goal_velocity", lower=-10, upper=10, default_value=0 + ), + "force": UniformFloatContextFeature( + "force", lower=-10, upper=10, default_value=0.001 + ), + "gravity": UniformFloatContextFeature( + "gravity", lower=-10, upper=10, default_value=0.0025 + ), + } + + +class CARLGymnaxMountainCarContinuous(CARLGymnaxMountainCar): + env_name: str = "MountainCarContinuous-v0" + module: str = "classic_control.continuous_mountain_car" + + @staticmethod + def get_context_features() -> dict[str, ContextFeature]: + return { + "max_speed": UniformFloatContextFeature( + "max_speed", lower=1e-3, upper=10, default_value=0.07 + ), + "goal_position": UniformFloatContextFeature( + "goal_position", lower=-2, upper=2, default_value=0.45 + ), + "goal_velocity": UniformFloatContextFeature( + "goal_velocity", lower=-10, upper=10, default_value=0 + ), + "power": UniformFloatContextFeature( + "power", lower=1e-6, upper=10, default_value=0.001 + ), + "gravity": UniformFloatContextFeature( + "gravity", lower=-10, upper=10, default_value=0.0025 + ), + } diff --git a/carl/envs/gymnax/classic_control/carl_gymnax_pendulum.py b/carl/envs/gymnax/classic_control/carl_gymnax_pendulum.py new file mode 100644 index 00000000..676477af --- /dev/null +++ b/carl/envs/gymnax/classic_control/carl_gymnax_pendulum.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from carl.context.context_space import ContextFeature, UniformFloatContextFeature +from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv + + +class CARLGymnaxPendulum(CARLGymnaxEnv): + env_name: str = "Pendulum-v1" + module: str = "classic_control.pendulum" + + @staticmethod + def get_context_features() -> dict[str, ContextFeature]: + return { + "dt": UniformFloatContextFeature( + "dt", lower=0.001, upper=10, default_value=0.05 + ), + "g": UniformFloatContextFeature( + "g", lower=-100, upper=100, default_value=10 + ), + "m": UniformFloatContextFeature( + "m", lower=1e-6, upper=100, default_value=1 + ), + "l": UniformFloatContextFeature( + "l", lower=1e-6, upper=100, default_value=1 + ), + "max_speed": UniformFloatContextFeature( + "max_speed", lower=0.08, upper=80, default_value=8 + ), + "max_torque": UniformFloatContextFeature( + "max_torque", lower=0.02, upper=40, default_value=2 + ), + } diff --git a/carl/envs/gymnax/utils.py b/carl/envs/gymnax/utils.py new file mode 100644 index 00000000..454335d1 --- /dev/null +++ b/carl/envs/gymnax/utils.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import Any + +import gymnasium +import gymnasium.spaces +import gymnax +from gymnax.environments.environment import Environment, EnvParams +from gymnax.environments.spaces import Space, gymnax_space_to_gym_space +from gymnax.wrappers.gym import GymnaxToGymWrapper + + +# Although this converts to gym, the step API already is for gymnasium +class CustomGymnaxToGymnasiumWrapper(GymnaxToGymWrapper): + def __init__( + self, env: Environment, params: EnvParams | None = None, seed: int | None = None + ): + super().__init__(env, params, seed) + + self._observation_space = SpaceWrapper( + gymnax_space_to_gym_space(self._env.observation_space(self.env_params)) + ) + + @property + def env(self) -> Environment: + return self._env + + @env.setter + def env(self, value: Environment) -> None: + self._env = value + + @property + def observation_space(self) -> gymnasium.Space: + return self._observation_space + + @observation_space.setter + def observation_space(self, value: Space) -> None: + self._observation_space = value + + +class SpaceWrapper(gymnasium.Space): + def __init__(self, space): + self.space = space + + def __getattr__(self, __name: str) -> Any: + return self.space.__getattr__(__name=__name) + + +def make_gymnax_env(env_name: str) -> gymnasium.Env: + # Make gymnax env + env, env_params = gymnax.make(env_id=env_name) + + # Convert gymnax to gymnasium API + env = CustomGymnaxToGymnasiumWrapper(env=env, params=env_params) + + return env diff --git a/changelog.md b/changelog.md index 76d89dae..bc4be264 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,6 @@ +# 1.1.0 +- Add gymnax classic control environments (#90) + # 1.0.0 Major overhaul of the CARL environment - Contexts are stored in each environment's class diff --git a/setup.py b/setup.py index ba5712e8..a90a3dd0 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,9 @@ def read_file(filepath: str) -> str: "dm_control": [ "dm_control>=1.0.3", ], + "gymnax": [ + "gymnax>=0.0.6", + ], "mario": [ "opencv-python>=4.8.0", "torch>=1.9.0", diff --git a/test/test_all_envs.py b/test/test_all_envs.py index 141b06b4..1d1e1240 100644 --- a/test/test_all_envs.py +++ b/test/test_all_envs.py @@ -16,6 +16,7 @@ def test_init_all_envs(self): env = ( # noqa: F841 local variable is assigned to but never used var() ) + _ = env.reset() except Exception as e: print(f"Cannot instantiate {var} environment.") raise e diff --git a/test/test_context_selector.py b/test/test_context_selector.py index 3e135ba8..784e5939 100644 --- a/test/test_context_selector.py +++ b/test/test_context_selector.py @@ -12,7 +12,7 @@ class TestContextSelection(unittest.TestCase): @staticmethod def generate_contexts() -> Dict[Any, Context]: keys = "abc" - context = {"dt": 0.03, "gravity": 10.0, "m": 1.0, "l": 1.8} + context = {"dt": 0.03, "g": 10.0, "m": 1.0, "l": 1.8} contexts = {k: context for k in keys} return contexts