diff --git a/.gitignore b/.gitignore index 04d64c5a..8bbc18a3 100644 --- a/.gitignore +++ b/.gitignore @@ -95,3 +95,6 @@ Desktop.ini *claude* *Claude* *CLAUDE* +**/.ipynb_checkpoints/ +**/.DS_Store +**/__pycache__/ diff --git a/src/envs/README.md b/src/envs/README.md index e45c181a..e6b8996e 100644 --- a/src/envs/README.md +++ b/src/envs/README.md @@ -237,6 +237,13 @@ Executes Python code in a sandboxed environment. Demonstrates: See: [`coding_env/README.md`](coding_env/README.md) +### Connect4 Environment +Location: `src/envs/connect4_env/` + +Wraps the `gym-connect4` implementation to provide a turnkey board-game benchmark that follows the OpenEnv API, including typed models, HTTP client, and Docker image. + +See: [`connect4_env/README.md`](connect4_env/README.md) + ## Best Practices ### 1. Type Safety diff --git a/src/envs/connect_four/README.md b/src/envs/connect_four/README.md new file mode 100644 index 00000000..ea4be648 --- /dev/null +++ b/src/envs/connect_four/README.md @@ -0,0 +1,21 @@ +# Connect Four (OpenSpiel) — OpenEnv Wrapper + +This environment wraps **OpenSpiel**’s `connect_four` and exposes an OpenEnv-style API. + +## Observation +- **Board**: `6 x 7` int grid in the _agent’s_ view + - `0` empty, `+1` agent discs (player 0), `-1` opponent discs (player 1). +- **Legal actions**: playable columns `[0..6]`. +- **current_player**: `+1` if agent to move, `-1` otherwise. +- **reward**: scalar, agent centric (`+1` win, `-1` loss, `0` otherwise). + +## Endpoints +- `POST /reset` → `{ observation, state }` +- `POST /step` w/ `{"column": int}` → `{ observation, state }` +- `GET /state` → current metadata +- `POST /close` → cleanup + +## Local run +```bash +pip install "open_spiel>=1.6" fastapi "uvicorn[standard]" numpy +uvicorn src.envs.connect_four.server.app:app --host 0.0.0.0 --port 8020 diff --git a/src/envs/connect_four/__init__.py b/src/envs/connect_four/__init__.py new file mode 100644 index 00000000..16a90cb6 --- /dev/null +++ b/src/envs/connect_four/__init__.py @@ -0,0 +1,9 @@ +from .models import ConnectFourAction, ConnectFourObservation, ConnectFourState +from .client import ConnectFourEnvClient + +__all__ = [ + "ConnectFourAction", + "ConnectFourObservation", + "ConnectFourState", + "ConnectFourEnvClient", +] diff --git a/src/envs/connect_four/client.py b/src/envs/connect_four/client.py new file mode 100644 index 00000000..9c97fd06 --- /dev/null +++ b/src/envs/connect_four/client.py @@ -0,0 +1,40 @@ +from __future__ import annotations +import requests +from typing import Tuple +from .models import ConnectFourAction, ConnectFourObservation, ConnectFourState + + +class ConnectFourEnvClient: + """ + Tiny HTTP client for the Connect Four server. + + Example: + env = ConnectFourEnvClient("http://localhost:8020") + obs, st = env.reset() + obs, st = env.step(ConnectFourAction(column=3)) + """ + def __init__(self, base_url: str): + self.base = base_url.rstrip("/") + + def reset(self) -> Tuple[ConnectFourObservation, ConnectFourState]: + r = requests.post(f"{self.base}/reset", timeout=30) + r.raise_for_status() + payload = r.json() + return ConnectFourObservation(**payload["observation"]), ConnectFourState(**payload["state"]) + + def step(self, action: ConnectFourAction) -> Tuple[ConnectFourObservation, ConnectFourState]: + r = requests.post(f"{self.base}/step", json=action.model_dump(), timeout=30) + r.raise_for_status() + payload = r.json() + return ConnectFourObservation(**payload["observation"]), ConnectFourState(**payload["state"]) + + def state(self) -> ConnectFourState: + r = requests.get(f"{self.base}/state", timeout=15) + r.raise_for_status() + return ConnectFourState(**r.json()) + + def close(self) -> None: + try: + requests.post(f"{self.base}/close", timeout=10) + except Exception: + pass diff --git a/src/envs/connect_four/models.py b/src/envs/connect_four/models.py new file mode 100644 index 00000000..f3d0559d --- /dev/null +++ b/src/envs/connect_four/models.py @@ -0,0 +1,31 @@ +from __future__ import annotations +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field + + +class ConnectFourAction(BaseModel): + column: int = Field(..., ge=0, le=6, description="Playable column 0..6") + + +class ConnectFourObservation(BaseModel): + # 6x7 int grid: 0 empty, +1 agent discs, -1 opponent discs + board: List[List[int]] + # list of playable columns (0..6), empty when done=True + legal_actions: List[int] + # +1 if agent (player 0) to move, -1 otherwise + current_player: int + # last column played, or None at the start + last_move: Optional[int] = None + # terminal flag + done: bool + # scalar reward in agent’s perspective: +1 win, -1 loss, 0 else + reward: float + # passthrough metadata + info: Dict[str, Any] = {} + + +class ConnectFourState(BaseModel): + rows: int = 6 + cols: int = 7 + move_count: int = 0 + episode_id: str = "" diff --git a/src/envs/connect_four/server/Dockerfile b/src/envs/connect_four/server/Dockerfile new file mode 100644 index 00000000..2871e990 --- /dev/null +++ b/src/envs/connect_four/server/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.11-slim + +# System basics (git not strictly required for OpenSpiel but handy for debugging) +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git \ + && rm -rf /var/lib/apt/lists/* + +# Python deps +# - open_spiel from PyPI (>=1.6 ships Linux wheels) +# - pin numpy<2.0 for broad compatibility with older stacks +RUN pip install --no-cache-dir "fastapi>=0.112" "uvicorn[standard]>=0.30" "numpy>=1.24,<2.0" "open_spiel>=1.6" + +# Copy project +WORKDIR /app +COPY . /app/ + +# Defaults (override at runtime) +ENV PORT=8020 +ENV OPENSPIEL_GAME=connect_four +ENV CONNECT4_AUTOPLAY_OPPONENT=false +ENV CONNECT4_OPP_POLICY=random + +EXPOSE 8020 +CMD ["sh", "-c", "uvicorn src.envs.connect_four.server.app:app --host 0.0.0.0 --port ${PORT}"] diff --git a/src/envs/connect_four/server/__init__.py b/src/envs/connect_four/server/__init__.py new file mode 100644 index 00000000..9eaacc70 --- /dev/null +++ b/src/envs/connect_four/server/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Connect Four environment server components.""" + +from .connect_four_environment import ConnectFourEnvironment + +__all__ = ["ConnectFourEnvironment"] diff --git a/src/envs/connect_four/server/app.py b/src/envs/connect_four/server/app.py new file mode 100644 index 00000000..7de1dd6b --- /dev/null +++ b/src/envs/connect_four/server/app.py @@ -0,0 +1,70 @@ +from __future__ import annotations +import os +from typing import Optional + +from fastapi import FastAPI +from pydantic import BaseModel + +from ..models import ConnectFourAction, ConnectFourObservation, ConnectFourState +from .connect_four_environment import ( + ConnectFourEnvironment, + ConnectFourConfig, +) + +# ------------ env config from environment variables ------------ +PORT = int(os.getenv("PORT", "8020")) +GAME_STRING = os.getenv("OPENSPIEL_GAME", "connect_four") +AUTO_OPP = os.getenv("CONNECT4_AUTOPLAY_OPPONENT", "false").lower() in {"1", "true", "yes"} +OPP_POLICY = os.getenv("CONNECT4_OPP_POLICY", "random") # random | lowest | highest + +# ------------------------- FastAPI app ------------------------- +app = FastAPI(title="OpenEnv • Connect Four (OpenSpiel)", version="1.0.0") + +_env: Optional[ConnectFourEnvironment] = None +_state = ConnectFourState() + +def _dump(model: BaseModel) -> dict: + return model.model_dump() if hasattr(model, "model_dump") else model.dict() + +def _ensure_env() -> ConnectFourEnvironment: + global _env + if _env is None: + cfg = ConnectFourConfig( + game_string=GAME_STRING, + autoplay_opponent=AUTO_OPP, + opponent_policy=OPP_POLICY, + ) + _env = ConnectFourEnvironment(cfg) + return _env + +# --------------------------- endpoints -------------------------- + +@app.post("/reset") +def reset(): + env = _ensure_env() + obs_dict, st_dict = env.reset() + global _state + _state = ConnectFourState(**st_dict) + return {"observation": _dump(ConnectFourObservation(**obs_dict)), "state": _dump(_state)} + +@app.post("/step") +def step(action: ConnectFourAction): + env = _ensure_env() + obs_dict, st_dict = env.step(action.column) + global _state + _state = ConnectFourState(**st_dict) + return {"observation": _dump(ConnectFourObservation(**obs_dict)), "state": _dump(_state)} + +@app.get("/state") +def state(): + return _dump(_state) + +@app.post("/close") +def close(): + global _env + try: + if _env is not None: + _env.close() + finally: + _env = None + return {"ok": True} diff --git a/src/envs/connect_four/server/connect_four_environment.py b/src/envs/connect_four/server/connect_four_environment.py new file mode 100644 index 00000000..849c97c9 --- /dev/null +++ b/src/envs/connect_four/server/connect_four_environment.py @@ -0,0 +1,218 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +try: + import pyspiel # OpenSpiel +except Exception as e: + raise ImportError( + "open_spiel (pyspiel) is required. Install with `pip install open_spiel`." + ) from e + + +@dataclass +class ConnectFourConfig: + game_string: str = "connect_four" + # If True, the env auto-plays the opponent (player 1) using a trivial policy + # whenever it becomes their turn (keeps a single-agent loop simple). + autoplay_opponent: bool = False + # Opponent policy: "random" | "lowest" | "highest" + opponent_policy: str = "random" + + +class ConnectFourEnvironment: + """OpenSpiel-backed Connect Four with OpenEnv-compatible semantics.""" + + ROWS = 6 + COLS = 7 + + def __init__(self, config: Optional[ConnectFourConfig] = None): + self.config = config or ConnectFourConfig() + self._game = pyspiel.load_game(self.config.game_string) + self._state = self._game.new_initial_state() + + # Agent = player 0; opponent = player 1 + self._agent_player: int = 0 + self._move_count: int = 0 + self._episode_id: str = "" + # cache of reconstructed grid (-1 empty, {0,1} owners) + self._grid_cache: Optional[np.ndarray] = None + + # ----------------------------- API ----------------------------- + + def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]: + if seed is not None: + np.random.seed(seed) + self._state = self._game.new_initial_state() + self._move_count = 0 + self._episode_id = self._new_episode_id() + self._grid_cache = None + obs = self._build_observation(done=False, reward=0.0, info={"engine": "open_spiel"}) + return obs, self._build_state() + + def step(self, column: int) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Apply agent move (column 0..6). Optionally autoplay opponent move.""" + assert 0 <= column < self.COLS, f"column out of range: {column}" + + self._maybe_autoplay_until_agent_turn() + + # Map to OpenSpiel action; legality guard + act = self._column_to_action(column) + legal = self._state.legal_actions() + if act not in legal: + info = {"error": "illegal_action", "legal_columns": self.legal_actions()} + obs = self._build_observation(done=True, reward=-1.0, info=info) + return obs, self._build_state() + + self._state.apply_action(act) + self._move_count += 1 + self._invalidate_grid_cache() + + if self._state.is_terminal(): + reward = self._terminal_reward_for_agent() + obs = self._build_observation(done=True, reward=reward, info={"engine": "open_spiel"}) + return obs, self._build_state() + + if self.config.autoplay_opponent: + self._autoplay_opponent_once() + if self._state.is_terminal(): + reward = self._terminal_reward_for_agent() + obs = self._build_observation(done=True, reward=reward, info={"engine": "open_spiel"}) + return obs, self._build_state() + + obs = self._build_observation(done=False, reward=0.0, info={"engine": "open_spiel"}) + return obs, self._build_state() + + def close(self) -> None: + # No special cleanup required + self._state = self._game.new_initial_state() + self._grid_cache = None + + # --------------------------- helpers --------------------------- + + def legal_actions(self) -> List[int]: + return sorted({self._action_to_column(a) for a in self._state.legal_actions()}) + + def current_player(self) -> int: + return 1 if self._state.current_player() == self._agent_player else -1 + + def board_agent_view(self) -> np.ndarray: + """Return 6x7 board: 0 empty, +1 agent discs, -1 opponent discs.""" + grid = self._reconstruct_grid_from_history() + board = np.zeros_like(grid, dtype=int) + board[grid == -1] = 0 + board[grid == self._agent_player] = 1 + board[(grid != -1) & (grid != self._agent_player)] = -1 + return board + + def _reconstruct_grid_from_history(self) -> np.ndarray: + """Rebuild grid (-1 empty, 0/1 owners) from action history.""" + if self._grid_cache is not None: + return self._grid_cache + grid = np.zeros((self.ROWS, self.COLS), dtype=int) - 1 # -1 empty + player = 0 # starts with player 0 + for act in self._state.history(): + col = self._action_to_column(act) + rr = self._lowest_empty_row(grid, col) + if rr is not None: + grid[rr, col] = player + player = 1 - player + self._grid_cache = grid + return grid + + @staticmethod + def _lowest_empty_row(grid: np.ndarray, col: int) -> Optional[int]: + for r in range(grid.shape[0] - 1, -1, -1): + if grid[r, col] == -1: + return r + return None + + def _invalidate_grid_cache(self) -> None: + self._grid_cache = None + + # ----- action mapping ----- + + def _column_to_action(self, col: int) -> int: + # OpenSpiel uses 0..6 column IDs as actions + # still verify against legal action list in case of variant configs + for a in self._state.legal_actions(): + if self._action_to_column(a) == col: + return a + return col + + @staticmethod + def _action_to_column(action: int) -> int: + return int(action) + + # ----- opponent autoplay ----- + + def _maybe_autoplay_until_agent_turn(self) -> None: + if not self.config.autoplay_opponent: + return + while self._state.current_player() != self._agent_player and not self._state.is_terminal(): + self._autoplay_opponent_once() + + def _autoplay_opponent_once(self) -> None: + if self._state.current_player() == self._agent_player or self._state.is_terminal(): + return + legal = self._state.legal_actions() + if not legal: + return + cols = [self._action_to_column(a) for a in legal] + if self.config.opponent_policy == "lowest": + chosen_col = min(cols) + elif self.config.opponent_policy == "highest": + chosen_col = max(cols) + else: + chosen_col = int(np.random.choice(cols)) + self._state.apply_action(self._column_to_action(chosen_col)) + self._invalidate_grid_cache() + + # ----- rewards ----- + + def _terminal_reward_for_agent(self) -> float: + if not self._state.is_terminal(): + return 0.0 + returns = self._state.returns() + val = float(returns[self._agent_player]) # >0 win, <0 loss, 0 draw + if val > 0: + return 1.0 + if val < 0: + return -1.0 + return 0.0 + + # ----- payloads ----- + + def _build_observation(self, done: bool, reward: float, info: Dict[str, Any]) -> Dict[str, Any]: + board = self.board_agent_view() + obs = { + "board": board.tolist(), + "legal_actions": [] if done else self.legal_actions(), + "current_player": self.current_player() if not done else 1, + "last_move": self._last_move_column(), + "done": bool(done), + "reward": float(reward), + "info": dict(info or {}), + } + return obs + + def _build_state(self) -> Dict[str, Any]: + return { + "rows": self.ROWS, + "cols": self.COLS, + "move_count": self._move_count, + "episode_id": self._episode_id, + } + + def _last_move_column(self) -> Optional[int]: + hist = self._state.history() + if not hist: + return None + return self._action_to_column(hist[-1]) + + @staticmethod + def _new_episode_id() -> str: + import uuid + return str(uuid.uuid4())