From 8323a32fbfc50c9898dc293724cb0d003e256ea3 Mon Sep 17 00:00:00 2001 From: VivekHaridas-01 Date: Sun, 26 Oct 2025 10:51:15 -0400 Subject: [PATCH 1/4] Added Connect4 Env wrapped on Gym --- src/envs/README.md | 7 + src/envs/connect4_env/README.md | 38 ++++ src/envs/connect4_env/__init__.py | 13 ++ src/envs/connect4_env/client.py | 42 ++++ src/envs/connect4_env/models.py | 34 +++ src/envs/connect4_env/server/Dockerfile | 23 ++ src/envs/connect4_env/server/__init__.py | 5 + src/envs/connect4_env/server/app.py | 14 ++ .../server/connect4_environment.py | 204 ++++++++++++++++++ tests/envs/test_connect4_env_smoke.py | 47 ++++ 10 files changed, 427 insertions(+) create mode 100644 src/envs/connect4_env/README.md create mode 100644 src/envs/connect4_env/__init__.py create mode 100644 src/envs/connect4_env/client.py create mode 100644 src/envs/connect4_env/models.py create mode 100644 src/envs/connect4_env/server/Dockerfile create mode 100644 src/envs/connect4_env/server/__init__.py create mode 100644 src/envs/connect4_env/server/app.py create mode 100644 src/envs/connect4_env/server/connect4_environment.py create mode 100644 tests/envs/test_connect4_env_smoke.py diff --git a/src/envs/README.md b/src/envs/README.md index e45c181a..e6b8996e 100644 --- a/src/envs/README.md +++ b/src/envs/README.md @@ -237,6 +237,13 @@ Executes Python code in a sandboxed environment. Demonstrates: See: [`coding_env/README.md`](coding_env/README.md) +### Connect4 Environment +Location: `src/envs/connect4_env/` + +Wraps the `gym-connect4` implementation to provide a turnkey board-game benchmark that follows the OpenEnv API, including typed models, HTTP client, and Docker image. + +See: [`connect4_env/README.md`](connect4_env/README.md) + ## Best Practices ### 1. 
Type Safety diff --git a/src/envs/connect4_env/README.md b/src/envs/connect4_env/README.md new file mode 100644 index 00000000..f85d477b --- /dev/null +++ b/src/envs/connect4_env/README.md @@ -0,0 +1,38 @@ +# Connect4 Environment + +This environment wraps the [`gym-connect4`](https://github.com/Danielhp95/gym-connect4) implementation inside OpenEnv. It exposes a turn-based 6x7 Connect Four board where the agent plays as player `+1` against the built-in opponent logic supplied by the Gym environment. + +## Action, Observation, State + +| Type | Fields | Description | +| --- | --- | --- | +| `Connect4Action` | `column: int` | 0-based column where the agent drops a disc. | +| `Connect4Observation` | `board: list[list[int]]`
`legal_actions: list[int]`
`current_player: int`
`last_move: Optional[int]`
`info: dict` | Board uses `1` for the agent, `-1` for the opponent, `0` for empty. Legal actions are the playable columns. When `done=True`, `legal_actions` is empty. Any metadata from Gym is forwarded through `info`. | +| `Connect4State` | `episode_id: str`
`step_count: int`
`rows: int`
`cols: int` | Mirrors the generic OpenEnv state and records the board geometry. | + +Rewards from Gym can be scalars or a 2-element vector. The server always scalarizes them into an agent-centric `float` (`r_agent - r_opponent` when two values are supplied). + +## Running the server + +```bash +uvicorn envs.connect4_env.server.app:app --host 0.0.0.0 --port 8000 +``` + +Set `GYM_CONNECT4_ID` if you need a custom Gym registration ID (default `Connect4-v0`). + +## Client usage + +```python +from envs.connect4_env import Connect4Env, Connect4Action + +client = Connect4Env(base_url="http://localhost:8000") + +result = client.reset() +print(result.observation.board) + +while not result.done: + action = Connect4Action(column=result.observation.legal_actions[0]) + result = client.step(action) + +print("Episode reward:", result.reward) +``` diff --git a/src/envs/connect4_env/__init__.py b/src/envs/connect4_env/__init__.py new file mode 100644 index 00000000..7696e631 --- /dev/null +++ b/src/envs/connect4_env/__init__.py @@ -0,0 +1,13 @@ +"""Connect4 OpenEnv package exports.""" + +from .client import Connect4Env +from .models import Connect4Action, Connect4Observation, Connect4State +from .server.connect4_environment import Connect4Environment + +__all__ = ( + "Connect4Action", + "Connect4Observation", + "Connect4State", + "Connect4Env", + "Connect4Environment", +) diff --git a/src/envs/connect4_env/client.py b/src/envs/connect4_env/client.py new file mode 100644 index 00000000..dfc92352 --- /dev/null +++ b/src/envs/connect4_env/client.py @@ -0,0 +1,42 @@ +"""HTTP client for the Connect4 OpenEnv environment.""" + +from __future__ import annotations + +from typing import Any, Dict + +from core.client_types import StepResult +from core.http_env_client import HTTPEnvClient + +from .models import Connect4Action, Connect4Observation, Connect4State + + +class Connect4Env(HTTPEnvClient[Connect4Action, Connect4Observation]): + """Thin HTTP client used by agents to interact with the 
Connect4 server.""" + + def _step_payload(self, action: Connect4Action) -> Dict[str, Any]: + return {"column": action.column, "metadata": action.metadata} + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Connect4Observation]: + obs_data = payload.get("observation", {}) + observation = Connect4Observation( + board=obs_data.get("board", []), + legal_actions=obs_data.get("legal_actions", []), + current_player=obs_data.get("current_player", 1), + last_move=obs_data.get("last_move"), + info=obs_data.get("info", {}), + done=payload.get("done", False), + reward=payload.get("reward"), + ) + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> Connect4State: + return Connect4State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + rows=payload.get("rows", 6), + cols=payload.get("cols", 7), + ) diff --git a/src/envs/connect4_env/models.py b/src/envs/connect4_env/models.py new file mode 100644 index 00000000..1416d29b --- /dev/null +++ b/src/envs/connect4_env/models.py @@ -0,0 +1,34 @@ +"""Data models for the Connect4 OpenEnv environment.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from core.env_server.types import Action, Observation, State + + +@dataclass(kw_only=True) +class Connect4Action(Action): + """Selects the column (0-indexed) where the agent wants to drop a disc.""" + + column: int + + +@dataclass(kw_only=True) +class Connect4Observation(Observation): + """Observation returned after every step/reset.""" + + board: List[List[int]] # 6x7 grid with 1 (agent), -1 (opponent), 0 (empty) + legal_actions: List[int] + current_player: int + last_move: Optional[int] = None + info: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Connect4State(State): + """Track episode metadata plus board geometry 
for convenience.""" + + rows: int = 6 + cols: int = 7 diff --git a/src/envs/connect4_env/server/Dockerfile b/src/envs/connect4_env/server/Dockerfile new file mode 100644 index 00000000..817f1123 --- /dev/null +++ b/src/envs/connect4_env/server/Dockerfile @@ -0,0 +1,23 @@ +# Build on top of the shared OpenEnv base image +ARG BASE_IMAGE=openenv-base:latest +FROM ${BASE_IMAGE} + +# Install git for pip VCS installs +RUN apt-get update && apt-get install -y --no-install-recommends git && \ + rm -rf /var/lib/apt/lists/* + +# Install environment-specific dependencies +RUN pip install --no-cache-dir "gym==0.25.2" "numpy<2.0" \ + git+https://github.com/Danielhp95/gym-connect4 + +# Copy the framework core plus this environment +COPY src/core/ /app/src/core/ +COPY src/envs/connect4_env/ /app/src/envs/connect4_env/ +COPY src/envs/connect4_env/README.md /app/README.md + +# Simple health check - the web UI reuses /health +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the FastAPI server +CMD ["uvicorn", "envs.connect4_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/src/envs/connect4_env/server/__init__.py b/src/envs/connect4_env/server/__init__.py new file mode 100644 index 00000000..aa367f39 --- /dev/null +++ b/src/envs/connect4_env/server/__init__.py @@ -0,0 +1,5 @@ +"""Server package for the Connect4 OpenEnv environment.""" + +from .connect4_environment import Connect4Environment + +__all__ = ("Connect4Environment",) diff --git a/src/envs/connect4_env/server/app.py b/src/envs/connect4_env/server/app.py new file mode 100644 index 00000000..b10352b7 --- /dev/null +++ b/src/envs/connect4_env/server/app.py @@ -0,0 +1,14 @@ +"""FastAPI entrypoint for the Connect4 OpenEnv server.""" + +from core.env_server.http_server import create_app + +from ..models import Connect4Action, Connect4Observation +from .connect4_environment import Connect4Environment + +env = 
Connect4Environment() +app = create_app(env, Connect4Action, Connect4Observation, env_name="connect4_env") + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/envs/connect4_env/server/connect4_environment.py b/src/envs/connect4_env/server/connect4_environment.py new file mode 100644 index 00000000..fb0a0f48 --- /dev/null +++ b/src/envs/connect4_env/server/connect4_environment.py @@ -0,0 +1,204 @@ +"""Gym-based Connect4 environment wrapped for OpenEnv.""" + +from __future__ import annotations + +import importlib +import os +from typing import Any, Dict, Tuple +from uuid import uuid4 + +import numpy as np + +from core.env_server.interfaces import Environment + +from ..models import Connect4Action, Connect4Observation, Connect4State + +# Ensure the third-party Gym env registers itself if present. +try: # pragma: no cover - optional dependency is best-effort + importlib.import_module("gym_connect4") +except Exception: # noqa: BLE001 + pass + +try: + import gym +except ImportError as exc: # pragma: no cover + raise ImportError( + "The Connect4 environment requires gym>=0.25. " + "Install it inside your Docker image or development venv." + ) from exc + + +def _scalarize_reward(reward: Any) -> float: + """Map scalar, vector, or ndarray rewards into a single float.""" + if isinstance(reward, (list, tuple, np.ndarray)): + arr = np.asarray(reward, dtype=float) + if arr.shape == (2,): + return float(arr[0] - arr[1]) + return float(arr.sum()) + return float(reward) + + +def _normalize_board(obs: Any) -> Tuple[np.ndarray, Dict[str, Any]]: + """ + Convert arbitrary Connect4 observations into a canonical 6x7 np.ndarray. + + Supports: (obs, info) tuples, 2x6x7 one-hot planes, 6x7x2 one-hot tensors, + or per-cell vectors embedded in object arrays. 
+ """ + info: Dict[str, Any] = {} + board = obs + if isinstance(obs, tuple) and len(obs) == 2: + board, info = obs + + arr = np.array(board, dtype=object) + + if arr.ndim == 2 and arr.dtype != object: + return arr.astype(int), info + + if arr.ndim == 3 and arr.dtype != object and arr.shape[0] == 2: + return (arr[0].astype(int) - arr[1].astype(int)), info + + if arr.ndim == 3 and arr.dtype != object and arr.shape[2] == 2: + return (arr[:, :, 0].astype(int) - arr[:, :, 1].astype(int)), info + + if ( + arr.ndim == 4 + and arr.dtype != object + and arr.shape[0] >= 1 + and arr.shape[1] == 3 + ): + # gym-connect4 returns a list of per-player 3-plane tensors with shape + # (players, channels=3, width, height). Convert the first player's view + # (agent perspective) into a signed board matrix. + player_view = arr[0] # shape (3, width, height) + pieces = player_view[1].astype(int) - player_view[2].astype(int) + # Convert to (rows, cols) with row zero on top. + return pieces.T, info + + if arr.ndim == 2 and arr.dtype == object: + h, w = arr.shape + out = np.zeros((h, w), dtype=int) + for r in range(h): + for c in range(w): + val = np.asarray(arr[r, c], dtype=int).ravel() + if val.size == 2: + out[r, c] = int(val[0] - val[1]) + elif val.size == 1: + out[r, c] = int(val[0]) + return out, info + + # Fallback: best effort for mismatched shapes + try: # pragma: no cover - defensive branch + if arr.ndim == 3 and arr.shape[0] == 2: + return (arr[0].astype(int) - arr[1].astype(int)), info + if arr.ndim == 3 and arr.shape[2] == 2: + return (arr[:, :, 0].astype(int) - arr[:, :, 1].astype(int)), info + except Exception: # noqa: BLE001 + pass + + return np.zeros((6, 7), dtype=int), info + + +def _legal_actions(board: np.ndarray) -> list[int]: + return [c for c in range(board.shape[1]) if board[0, c] == 0] + + +def _current_player(info: Dict[str, Any], board: np.ndarray) -> int: + try: + cp = int(info.get("current_player", 0)) + if cp in (1, -1): + return cp + except Exception: # noqa: 
BLE001 + pass + + p1 = int((board == 1).sum()) + p2 = int((board == -1).sum()) + return 1 if p1 == p2 else -1 + + +class Connect4Environment(Environment): + """Wrap the gym-connect4 environment so it can be served over HTTP.""" + + def __init__(self, gym_id: str | None = None): + super().__init__() + self._gym_id = gym_id or os.getenv("GYM_CONNECT4_ID", "Connect4-v0") + self._env: gym.Env | None = None + self._state = Connect4State() + + def _ensure_env(self) -> gym.Env: + if self._env is None: + self._env = gym.make(self._gym_id) + return self._env + + def reset(self) -> Connect4Observation: + env = self._ensure_env() + raw_obs = env.reset() + board, info = _normalize_board(raw_obs) + rows, cols = board.shape + self._state = Connect4State( + episode_id=str(uuid4()), + step_count=0, + rows=rows, + cols=cols, + ) + + legal_actions = info.get("legal_actions") if info else None + if legal_actions is None: + legal_actions = _legal_actions(board) + + return Connect4Observation( + board=board.tolist(), + legal_actions=list(legal_actions), + current_player=_current_player(info, board), + last_move=info.get("last_move"), + reward=0.0, + done=False, + info=info, + ) + + def step(self, action: Connect4Action) -> Connect4Observation: # type: ignore[override] + env = self._ensure_env() + result = env.step(int(action.column)) + + # Gym 0.25 returns 4-tuple, 0.26+ returns 5-tuple. 
+ if isinstance(result, tuple) and len(result) == 5: + obs, reward, terminated, truncated, info = result + elif isinstance(result, tuple) and len(result) == 4: + obs, reward, done, info = result + terminated, truncated = bool(done), False + else: # pragma: no cover - defensive branch + raise RuntimeError( + f"Unexpected Gym step return type for Connect4: {type(result)}" + ) + + done = bool(terminated or truncated) + board, info2 = _normalize_board(obs) + merged_info: Dict[str, Any] = info or {} + merged_info.update(info2 or {}) + + self._state.step_count += 1 + + legal_actions = merged_info.get("legal_actions") + if done: + legal_actions = [] + elif legal_actions is None: + legal_actions = _legal_actions(board) + + return Connect4Observation( + board=board.tolist(), + legal_actions=list(legal_actions), + current_player=_current_player(merged_info, board), + last_move=merged_info.get("last_move"), + done=done, + reward=_scalarize_reward(reward), + info=merged_info, + ) + + @property + def state(self) -> Connect4State: + return self._state + + def close(self) -> None: + if self._env is not None and hasattr(self._env, "close"): + self._env.close() + self._env = None diff --git a/tests/envs/test_connect4_env_smoke.py b/tests/envs/test_connect4_env_smoke.py new file mode 100644 index 00000000..8e5ef220 --- /dev/null +++ b/tests/envs/test_connect4_env_smoke.py @@ -0,0 +1,47 @@ +"""Basic smoke tests for the Connect4 OpenEnv environment.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Ensure "src" is on the import path when tests run via pytest. 
+sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + +import pytest + +pytest.importorskip("gym") + +from envs.connect4_env.models import Connect4Action # noqa: E402 +from envs.connect4_env.server.connect4_environment import Connect4Environment # noqa: E402 + + +def _assert_board_shape(board: list[list[int]], rows: int, cols: int) -> None: + assert len(board) == rows + assert all(len(row) == cols for row in board) + + +def test_connect4_environment_smoke_run() -> None: + """Reset and step through a short sequence to ensure env wiring works.""" + env = Connect4Environment() + + obs = env.reset() + _assert_board_shape(obs.board, env.state.rows, env.state.cols) + assert obs.legal_actions, "Reset should expose at least one legal move" + assert all(0 <= c < env.state.cols for c in obs.legal_actions) + assert env.state.step_count == 0 + + # Take a handful of legal moves; stop early if the episode terminates. + max_steps = env.state.rows * 2 # Plenty to detect regressions without full episode + for expected_step in range(1, max_steps + 1): + move = obs.legal_actions[0] + obs = env.step(Connect4Action(column=move)) + assert env.state.step_count == expected_step + _assert_board_shape(obs.board, env.state.rows, env.state.cols) + assert isinstance(obs.reward, float) + if obs.done: + break + assert obs.legal_actions, "Episode should offer moves until done" + assert all(0 <= c < env.state.cols for c in obs.legal_actions) + + env.close() From b57ce6d97a93f8f09c646d96c4bffa9d9ce7ec2a Mon Sep 17 00:00:00 2001 From: VivekHaridas-01 Date: Mon, 27 Oct 2025 17:13:00 -0400 Subject: [PATCH 2/4] chore: ignore ipynb checkpoints, DS_Store, and __pycache__ --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 04d64c5a..8bbc18a3 100644 --- a/.gitignore +++ b/.gitignore @@ -95,3 +95,6 @@ Desktop.ini *claude* *Claude* *CLAUDE* +**/.ipynb_checkpoints/ +**/.DS_Store +**/__pycache__/ From 5bbc165da225d327277154ae6e2b0b8e4286c1ab Mon 
Sep 17 00:00:00 2001 From: VivekHaridas-01 Date: Mon, 27 Oct 2025 17:13:57 -0400 Subject: [PATCH 3/4] chore: remove old connect4_env implementation --- src/envs/connect4_env/README.md | 38 ---- src/envs/connect4_env/__init__.py | 13 -- src/envs/connect4_env/client.py | 42 ---- src/envs/connect4_env/models.py | 34 --- src/envs/connect4_env/server/Dockerfile | 23 -- src/envs/connect4_env/server/__init__.py | 5 - src/envs/connect4_env/server/app.py | 14 -- .../server/connect4_environment.py | 204 ------------------ tests/envs/test_connect4_env_smoke.py | 47 ---- 9 files changed, 420 deletions(-) delete mode 100644 src/envs/connect4_env/README.md delete mode 100644 src/envs/connect4_env/__init__.py delete mode 100644 src/envs/connect4_env/client.py delete mode 100644 src/envs/connect4_env/models.py delete mode 100644 src/envs/connect4_env/server/Dockerfile delete mode 100644 src/envs/connect4_env/server/__init__.py delete mode 100644 src/envs/connect4_env/server/app.py delete mode 100644 src/envs/connect4_env/server/connect4_environment.py delete mode 100644 tests/envs/test_connect4_env_smoke.py diff --git a/src/envs/connect4_env/README.md b/src/envs/connect4_env/README.md deleted file mode 100644 index f85d477b..00000000 --- a/src/envs/connect4_env/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Connect4 Environment - -This environment wraps the [`gym-connect4`](https://github.com/Danielhp95/gym-connect4) implementation inside OpenEnv. It exposes a turn-based 6x7 Connect Four board where the agent plays as player `+1` against the built-in opponent logic supplied by the Gym environment. - -## Action, Observation, State - -| Type | Fields | Description | -| --- | --- | --- | -| `Connect4Action` | `column: int` | 0-based column where the agent drops a disc. | -| `Connect4Observation` | `board: list[list[int]]`
`legal_actions: list[int]`
`current_player: int`
`last_move: Optional[int]`
`info: dict` | Board uses `1` for the agent, `-1` for the opponent, `0` for empty. Legal actions are the playable columns. When `done=True`, `legal_actions` is empty. Any metadata from Gym is forwarded through `info`. | -| `Connect4State` | `episode_id: str`
`step_count: int`
`rows: int`
`cols: int` | Mirrors the generic OpenEnv state and records the board geometry. | - -Rewards from Gym can be scalars or a 2-element vector. The server always scalarizes them into an agent-centric `float` (`r_agent - r_opponent` when two values are supplied). - -## Running the server - -```bash -uvicorn envs.connect4_env.server.app:app --host 0.0.0.0 --port 8000 -``` - -Set `GYM_CONNECT4_ID` if you need a custom Gym registration ID (default `Connect4-v0`). - -## Client usage - -```python -from envs.connect4_env import Connect4Env, Connect4Action - -client = Connect4Env(base_url="http://localhost:8000") - -result = client.reset() -print(result.observation.board) - -while not result.done: - action = Connect4Action(column=result.observation.legal_actions[0]) - result = client.step(action) - -print("Episode reward:", result.reward) -``` diff --git a/src/envs/connect4_env/__init__.py b/src/envs/connect4_env/__init__.py deleted file mode 100644 index 7696e631..00000000 --- a/src/envs/connect4_env/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Connect4 OpenEnv package exports.""" - -from .client import Connect4Env -from .models import Connect4Action, Connect4Observation, Connect4State -from .server.connect4_environment import Connect4Environment - -__all__ = ( - "Connect4Action", - "Connect4Observation", - "Connect4State", - "Connect4Env", - "Connect4Environment", -) diff --git a/src/envs/connect4_env/client.py b/src/envs/connect4_env/client.py deleted file mode 100644 index dfc92352..00000000 --- a/src/envs/connect4_env/client.py +++ /dev/null @@ -1,42 +0,0 @@ -"""HTTP client for the Connect4 OpenEnv environment.""" - -from __future__ import annotations - -from typing import Any, Dict - -from core.client_types import StepResult -from core.http_env_client import HTTPEnvClient - -from .models import Connect4Action, Connect4Observation, Connect4State - - -class Connect4Env(HTTPEnvClient[Connect4Action, Connect4Observation]): - """Thin HTTP client used by agents to interact 
with the Connect4 server.""" - - def _step_payload(self, action: Connect4Action) -> Dict[str, Any]: - return {"column": action.column, "metadata": action.metadata} - - def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Connect4Observation]: - obs_data = payload.get("observation", {}) - observation = Connect4Observation( - board=obs_data.get("board", []), - legal_actions=obs_data.get("legal_actions", []), - current_player=obs_data.get("current_player", 1), - last_move=obs_data.get("last_move"), - info=obs_data.get("info", {}), - done=payload.get("done", False), - reward=payload.get("reward"), - ) - return StepResult( - observation=observation, - reward=payload.get("reward"), - done=payload.get("done", False), - ) - - def _parse_state(self, payload: Dict[str, Any]) -> Connect4State: - return Connect4State( - episode_id=payload.get("episode_id"), - step_count=payload.get("step_count", 0), - rows=payload.get("rows", 6), - cols=payload.get("cols", 7), - ) diff --git a/src/envs/connect4_env/models.py b/src/envs/connect4_env/models.py deleted file mode 100644 index 1416d29b..00000000 --- a/src/envs/connect4_env/models.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Data models for the Connect4 OpenEnv environment.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - -from core.env_server.types import Action, Observation, State - - -@dataclass(kw_only=True) -class Connect4Action(Action): - """Selects the column (0-indexed) where the agent wants to drop a disc.""" - - column: int - - -@dataclass(kw_only=True) -class Connect4Observation(Observation): - """Observation returned after every step/reset.""" - - board: List[List[int]] # 6x7 grid with 1 (agent), -1 (opponent), 0 (empty) - legal_actions: List[int] - current_player: int - last_move: Optional[int] = None - info: Dict[str, Any] = field(default_factory=dict) - - -@dataclass -class Connect4State(State): - """Track episode metadata plus 
board geometry for convenience.""" - - rows: int = 6 - cols: int = 7 diff --git a/src/envs/connect4_env/server/Dockerfile b/src/envs/connect4_env/server/Dockerfile deleted file mode 100644 index 817f1123..00000000 --- a/src/envs/connect4_env/server/Dockerfile +++ /dev/null @@ -1,23 +0,0 @@ -# Build on top of the shared OpenEnv base image -ARG BASE_IMAGE=openenv-base:latest -FROM ${BASE_IMAGE} - -# Install git for pip VCS installs -RUN apt-get update && apt-get install -y --no-install-recommends git && \ - rm -rf /var/lib/apt/lists/* - -# Install environment-specific dependencies -RUN pip install --no-cache-dir "gym==0.25.2" "numpy<2.0" \ - git+https://github.com/Danielhp95/gym-connect4 - -# Copy the framework core plus this environment -COPY src/core/ /app/src/core/ -COPY src/envs/connect4_env/ /app/src/envs/connect4_env/ -COPY src/envs/connect4_env/README.md /app/README.md - -# Simple health check - the web UI reuses /health -HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ - CMD curl -f http://localhost:8000/health || exit 1 - -# Run the FastAPI server -CMD ["uvicorn", "envs.connect4_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/src/envs/connect4_env/server/__init__.py b/src/envs/connect4_env/server/__init__.py deleted file mode 100644 index aa367f39..00000000 --- a/src/envs/connect4_env/server/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Server package for the Connect4 OpenEnv environment.""" - -from .connect4_environment import Connect4Environment - -__all__ = ("Connect4Environment",) diff --git a/src/envs/connect4_env/server/app.py b/src/envs/connect4_env/server/app.py deleted file mode 100644 index b10352b7..00000000 --- a/src/envs/connect4_env/server/app.py +++ /dev/null @@ -1,14 +0,0 @@ -"""FastAPI entrypoint for the Connect4 OpenEnv server.""" - -from core.env_server.http_server import create_app - -from ..models import Connect4Action, Connect4Observation -from .connect4_environment import 
Connect4Environment - -env = Connect4Environment() -app = create_app(env, Connect4Action, Connect4Observation, env_name="connect4_env") - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/envs/connect4_env/server/connect4_environment.py b/src/envs/connect4_env/server/connect4_environment.py deleted file mode 100644 index fb0a0f48..00000000 --- a/src/envs/connect4_env/server/connect4_environment.py +++ /dev/null @@ -1,204 +0,0 @@ -"""Gym-based Connect4 environment wrapped for OpenEnv.""" - -from __future__ import annotations - -import importlib -import os -from typing import Any, Dict, Tuple -from uuid import uuid4 - -import numpy as np - -from core.env_server.interfaces import Environment - -from ..models import Connect4Action, Connect4Observation, Connect4State - -# Ensure the third-party Gym env registers itself if present. -try: # pragma: no cover - optional dependency is best-effort - importlib.import_module("gym_connect4") -except Exception: # noqa: BLE001 - pass - -try: - import gym -except ImportError as exc: # pragma: no cover - raise ImportError( - "The Connect4 environment requires gym>=0.25. " - "Install it inside your Docker image or development venv." - ) from exc - - -def _scalarize_reward(reward: Any) -> float: - """Map scalar, vector, or ndarray rewards into a single float.""" - if isinstance(reward, (list, tuple, np.ndarray)): - arr = np.asarray(reward, dtype=float) - if arr.shape == (2,): - return float(arr[0] - arr[1]) - return float(arr.sum()) - return float(reward) - - -def _normalize_board(obs: Any) -> Tuple[np.ndarray, Dict[str, Any]]: - """ - Convert arbitrary Connect4 observations into a canonical 6x7 np.ndarray. - - Supports: (obs, info) tuples, 2x6x7 one-hot planes, 6x7x2 one-hot tensors, - or per-cell vectors embedded in object arrays. 
- """ - info: Dict[str, Any] = {} - board = obs - if isinstance(obs, tuple) and len(obs) == 2: - board, info = obs - - arr = np.array(board, dtype=object) - - if arr.ndim == 2 and arr.dtype != object: - return arr.astype(int), info - - if arr.ndim == 3 and arr.dtype != object and arr.shape[0] == 2: - return (arr[0].astype(int) - arr[1].astype(int)), info - - if arr.ndim == 3 and arr.dtype != object and arr.shape[2] == 2: - return (arr[:, :, 0].astype(int) - arr[:, :, 1].astype(int)), info - - if ( - arr.ndim == 4 - and arr.dtype != object - and arr.shape[0] >= 1 - and arr.shape[1] == 3 - ): - # gym-connect4 returns a list of per-player 3-plane tensors with shape - # (players, channels=3, width, height). Convert the first player's view - # (agent perspective) into a signed board matrix. - player_view = arr[0] # shape (3, width, height) - pieces = player_view[1].astype(int) - player_view[2].astype(int) - # Convert to (rows, cols) with row zero on top. - return pieces.T, info - - if arr.ndim == 2 and arr.dtype == object: - h, w = arr.shape - out = np.zeros((h, w), dtype=int) - for r in range(h): - for c in range(w): - val = np.asarray(arr[r, c], dtype=int).ravel() - if val.size == 2: - out[r, c] = int(val[0] - val[1]) - elif val.size == 1: - out[r, c] = int(val[0]) - return out, info - - # Fallback: best effort for mismatched shapes - try: # pragma: no cover - defensive branch - if arr.ndim == 3 and arr.shape[0] == 2: - return (arr[0].astype(int) - arr[1].astype(int)), info - if arr.ndim == 3 and arr.shape[2] == 2: - return (arr[:, :, 0].astype(int) - arr[:, :, 1].astype(int)), info - except Exception: # noqa: BLE001 - pass - - return np.zeros((6, 7), dtype=int), info - - -def _legal_actions(board: np.ndarray) -> list[int]: - return [c for c in range(board.shape[1]) if board[0, c] == 0] - - -def _current_player(info: Dict[str, Any], board: np.ndarray) -> int: - try: - cp = int(info.get("current_player", 0)) - if cp in (1, -1): - return cp - except Exception: # noqa: 
BLE001 - pass - - p1 = int((board == 1).sum()) - p2 = int((board == -1).sum()) - return 1 if p1 == p2 else -1 - - -class Connect4Environment(Environment): - """Wrap the gym-connect4 environment so it can be served over HTTP.""" - - def __init__(self, gym_id: str | None = None): - super().__init__() - self._gym_id = gym_id or os.getenv("GYM_CONNECT4_ID", "Connect4-v0") - self._env: gym.Env | None = None - self._state = Connect4State() - - def _ensure_env(self) -> gym.Env: - if self._env is None: - self._env = gym.make(self._gym_id) - return self._env - - def reset(self) -> Connect4Observation: - env = self._ensure_env() - raw_obs = env.reset() - board, info = _normalize_board(raw_obs) - rows, cols = board.shape - self._state = Connect4State( - episode_id=str(uuid4()), - step_count=0, - rows=rows, - cols=cols, - ) - - legal_actions = info.get("legal_actions") if info else None - if legal_actions is None: - legal_actions = _legal_actions(board) - - return Connect4Observation( - board=board.tolist(), - legal_actions=list(legal_actions), - current_player=_current_player(info, board), - last_move=info.get("last_move"), - reward=0.0, - done=False, - info=info, - ) - - def step(self, action: Connect4Action) -> Connect4Observation: # type: ignore[override] - env = self._ensure_env() - result = env.step(int(action.column)) - - # Gym 0.25 returns 4-tuple, 0.26+ returns 5-tuple. 
- if isinstance(result, tuple) and len(result) == 5: - obs, reward, terminated, truncated, info = result - elif isinstance(result, tuple) and len(result) == 4: - obs, reward, done, info = result - terminated, truncated = bool(done), False - else: # pragma: no cover - defensive branch - raise RuntimeError( - f"Unexpected Gym step return type for Connect4: {type(result)}" - ) - - done = bool(terminated or truncated) - board, info2 = _normalize_board(obs) - merged_info: Dict[str, Any] = info or {} - merged_info.update(info2 or {}) - - self._state.step_count += 1 - - legal_actions = merged_info.get("legal_actions") - if done: - legal_actions = [] - elif legal_actions is None: - legal_actions = _legal_actions(board) - - return Connect4Observation( - board=board.tolist(), - legal_actions=list(legal_actions), - current_player=_current_player(merged_info, board), - last_move=merged_info.get("last_move"), - done=done, - reward=_scalarize_reward(reward), - info=merged_info, - ) - - @property - def state(self) -> Connect4State: - return self._state - - def close(self) -> None: - if self._env is not None and hasattr(self._env, "close"): - self._env.close() - self._env = None diff --git a/tests/envs/test_connect4_env_smoke.py b/tests/envs/test_connect4_env_smoke.py deleted file mode 100644 index 8e5ef220..00000000 --- a/tests/envs/test_connect4_env_smoke.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Basic smoke tests for the Connect4 OpenEnv environment.""" - -from __future__ import annotations - -import sys -from pathlib import Path - -# Ensure "src" is on the import path when tests run via pytest. 
-sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) - -import pytest - -pytest.importorskip("gym") - -from envs.connect4_env.models import Connect4Action # noqa: E402 -from envs.connect4_env.server.connect4_environment import Connect4Environment # noqa: E402 - - -def _assert_board_shape(board: list[list[int]], rows: int, cols: int) -> None: - assert len(board) == rows - assert all(len(row) == cols for row in board) - - -def test_connect4_environment_smoke_run() -> None: - """Reset and step through a short sequence to ensure env wiring works.""" - env = Connect4Environment() - - obs = env.reset() - _assert_board_shape(obs.board, env.state.rows, env.state.cols) - assert obs.legal_actions, "Reset should expose at least one legal move" - assert all(0 <= c < env.state.cols for c in obs.legal_actions) - assert env.state.step_count == 0 - - # Take a handful of legal moves; stop early if the episode terminates. - max_steps = env.state.rows * 2 # Plenty to detect regressions without full episode - for expected_step in range(1, max_steps + 1): - move = obs.legal_actions[0] - obs = env.step(Connect4Action(column=move)) - assert env.state.step_count == expected_step - _assert_board_shape(obs.board, env.state.rows, env.state.cols) - assert isinstance(obs.reward, float) - if obs.done: - break - assert obs.legal_actions, "Episode should offer moves until done" - assert all(0 <= c < env.state.cols for c in obs.legal_actions) - - env.close() From f2a248afc485f1d6b9be5beb5454f10ecd477a2d Mon Sep 17 00:00:00 2001 From: VivekHaridas-01 Date: Mon, 27 Oct 2025 17:16:15 -0400 Subject: [PATCH 4/4] feat(env/connect_four): new implementation, updated docs and server files --- src/envs/connect_four/README.md | 21 ++ src/envs/connect_four/__init__.py | 9 + src/envs/connect_four/client.py | 40 ++++ src/envs/connect_four/models.py | 31 +++ src/envs/connect_four/server/Dockerfile | 24 ++ src/envs/connect_four/server/__init__.py | 11 + src/envs/connect_four/server/app.py | 70 
++++++ .../server/connect_four_environment.py | 218 ++++++++++++++++++ 8 files changed, 424 insertions(+) create mode 100644 src/envs/connect_four/README.md create mode 100644 src/envs/connect_four/__init__.py create mode 100644 src/envs/connect_four/client.py create mode 100644 src/envs/connect_four/models.py create mode 100644 src/envs/connect_four/server/Dockerfile create mode 100644 src/envs/connect_four/server/__init__.py create mode 100644 src/envs/connect_four/server/app.py create mode 100644 src/envs/connect_four/server/connect_four_environment.py diff --git a/src/envs/connect_four/README.md b/src/envs/connect_four/README.md new file mode 100644 index 00000000..ea4be648 --- /dev/null +++ b/src/envs/connect_four/README.md @@ -0,0 +1,21 @@ +# Connect Four (OpenSpiel) — OpenEnv Wrapper + +This environment wraps **OpenSpiel**’s `connect_four` and exposes an OpenEnv-style API. + +## Observation +- **Board**: `6 x 7` int grid in the _agent’s_ view + - `0` empty, `+1` agent discs (player 0), `-1` opponent discs (player 1). +- **Legal actions**: playable columns `[0..6]`. +- **current_player**: `+1` if agent to move, `-1` otherwise. +- **reward**: scalar, agent centric (`+1` win, `-1` loss, `0` otherwise). 
+ +## Endpoints +- `POST /reset` → `{ observation, state }` +- `POST /step` w/ `{"column": int}` → `{ observation, state }` +- `GET /state` → current metadata +- `POST /close` → cleanup + +## Local run +```bash +pip install "open_spiel>=1.6" fastapi "uvicorn[standard]" numpy +uvicorn src.envs.connect_four.server.app:app --host 0.0.0.0 --port 8020 diff --git a/src/envs/connect_four/__init__.py b/src/envs/connect_four/__init__.py new file mode 100644 index 00000000..16a90cb6 --- /dev/null +++ b/src/envs/connect_four/__init__.py @@ -0,0 +1,9 @@ +from .models import ConnectFourAction, ConnectFourObservation, ConnectFourState +from .client import ConnectFourEnvClient + +__all__ = [ + "ConnectFourAction", + "ConnectFourObservation", + "ConnectFourState", + "ConnectFourEnvClient", +] diff --git a/src/envs/connect_four/client.py b/src/envs/connect_four/client.py new file mode 100644 index 00000000..9c97fd06 --- /dev/null +++ b/src/envs/connect_four/client.py @@ -0,0 +1,40 @@ +from __future__ import annotations +import requests +from typing import Tuple +from .models import ConnectFourAction, ConnectFourObservation, ConnectFourState + + +class ConnectFourEnvClient: + """ + Tiny HTTP client for the Connect Four server. 
+ + Example: + env = ConnectFourEnvClient("http://localhost:8020") + obs, st = env.reset() + obs, st = env.step(ConnectFourAction(column=3)) + """ + def __init__(self, base_url: str): + self.base = base_url.rstrip("/") + + def reset(self) -> Tuple[ConnectFourObservation, ConnectFourState]: + r = requests.post(f"{self.base}/reset", timeout=30) + r.raise_for_status() + payload = r.json() + return ConnectFourObservation(**payload["observation"]), ConnectFourState(**payload["state"]) + + def step(self, action: ConnectFourAction) -> Tuple[ConnectFourObservation, ConnectFourState]: + r = requests.post(f"{self.base}/step", json=action.model_dump(), timeout=30) + r.raise_for_status() + payload = r.json() + return ConnectFourObservation(**payload["observation"]), ConnectFourState(**payload["state"]) + + def state(self) -> ConnectFourState: + r = requests.get(f"{self.base}/state", timeout=15) + r.raise_for_status() + return ConnectFourState(**r.json()) + + def close(self) -> None: + try: + requests.post(f"{self.base}/close", timeout=10) + except Exception: + pass diff --git a/src/envs/connect_four/models.py b/src/envs/connect_four/models.py new file mode 100644 index 00000000..f3d0559d --- /dev/null +++ b/src/envs/connect_four/models.py @@ -0,0 +1,31 @@ +from __future__ import annotations +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field + + +class ConnectFourAction(BaseModel): + column: int = Field(..., ge=0, le=6, description="Playable column 0..6") + + +class ConnectFourObservation(BaseModel): + # 6x7 int grid: 0 empty, +1 agent discs, -1 opponent discs + board: List[List[int]] + # list of playable columns (0..6), empty when done=True + legal_actions: List[int] + # +1 if agent (player 0) to move, -1 otherwise + current_player: int + # last column played, or None at the start + last_move: Optional[int] = None + # terminal flag + done: bool + # scalar reward in agent’s perspective: +1 win, -1 loss, 0 else + reward: float + # passthrough 
metadata + info: Dict[str, Any] = {} + + +class ConnectFourState(BaseModel): + rows: int = 6 + cols: int = 7 + move_count: int = 0 + episode_id: str = "" diff --git a/src/envs/connect_four/server/Dockerfile b/src/envs/connect_four/server/Dockerfile new file mode 100644 index 00000000..2871e990 --- /dev/null +++ b/src/envs/connect_four/server/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.11-slim + +# System basics (git not strictly required for OpenSpiel but handy for debugging) +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git \ + && rm -rf /var/lib/apt/lists/* + +# Python deps +# - open_spiel from PyPI (>=1.6 ships Linux wheels) +# - pin numpy<2.0 for broad compatibility with older stacks +RUN pip install --no-cache-dir "fastapi>=0.112" "uvicorn[standard]>=0.30" "numpy>=1.24,<2.0" "open_spiel>=1.6" + +# Copy project +WORKDIR /app +COPY . /app/ + +# Defaults (override at runtime) +ENV PORT=8020 +ENV OPENSPIEL_GAME=connect_four +ENV CONNECT4_AUTOPLAY_OPPONENT=false +ENV CONNECT4_OPP_POLICY=random + +EXPOSE 8020 +CMD ["sh", "-c", "uvicorn src.envs.connect_four.server.app:app --host 0.0.0.0 --port ${PORT}"] diff --git a/src/envs/connect_four/server/__init__.py b/src/envs/connect_four/server/__init__.py new file mode 100644 index 00000000..9eaacc70 --- /dev/null +++ b/src/envs/connect_four/server/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Connect Four environment server components.""" + +from .connect_four_environment import ConnectFourEnvironment + +__all__ = ["ConnectFourEnvironment"] diff --git a/src/envs/connect_four/server/app.py b/src/envs/connect_four/server/app.py new file mode 100644 index 00000000..7de1dd6b --- /dev/null +++ b/src/envs/connect_four/server/app.py @@ -0,0 +1,70 @@ +from __future__ import annotations +import os +from typing import Optional + +from fastapi import FastAPI +from pydantic import BaseModel + +from ..models import ConnectFourAction, ConnectFourObservation, ConnectFourState +from .connect_four_environment import ( + ConnectFourEnvironment, + ConnectFourConfig, +) + +# ------------ env config from environment variables ------------ +PORT = int(os.getenv("PORT", "8020")) +GAME_STRING = os.getenv("OPENSPIEL_GAME", "connect_four") +AUTO_OPP = os.getenv("CONNECT4_AUTOPLAY_OPPONENT", "false").lower() in {"1", "true", "yes"} +OPP_POLICY = os.getenv("CONNECT4_OPP_POLICY", "random") # random | lowest | highest + +# ------------------------- FastAPI app ------------------------- +app = FastAPI(title="OpenEnv • Connect Four (OpenSpiel)", version="1.0.0") + +_env: Optional[ConnectFourEnvironment] = None +_state = ConnectFourState() + +def _dump(model: BaseModel) -> dict: + return model.model_dump() if hasattr(model, "model_dump") else model.dict() + +def _ensure_env() -> ConnectFourEnvironment: + global _env + if _env is None: + cfg = ConnectFourConfig( + game_string=GAME_STRING, + autoplay_opponent=AUTO_OPP, + opponent_policy=OPP_POLICY, + ) + _env = ConnectFourEnvironment(cfg) + return _env + +# --------------------------- endpoints -------------------------- + +@app.post("/reset") +def reset(): + env = _ensure_env() + obs_dict, st_dict = env.reset() + global _state + _state = ConnectFourState(**st_dict) + return {"observation": _dump(ConnectFourObservation(**obs_dict)), "state": _dump(_state)} + +@app.post("/step") +def step(action: ConnectFourAction): + env = 
_ensure_env() + obs_dict, st_dict = env.step(action.column) + global _state + _state = ConnectFourState(**st_dict) + return {"observation": _dump(ConnectFourObservation(**obs_dict)), "state": _dump(_state)} + +@app.get("/state") +def state(): + return _dump(_state) + +@app.post("/close") +def close(): + global _env + try: + if _env is not None: + _env.close() + finally: + _env = None + return {"ok": True} diff --git a/src/envs/connect_four/server/connect_four_environment.py b/src/envs/connect_four/server/connect_four_environment.py new file mode 100644 index 00000000..849c97c9 --- /dev/null +++ b/src/envs/connect_four/server/connect_four_environment.py @@ -0,0 +1,218 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +try: + import pyspiel # OpenSpiel +except Exception as e: + raise ImportError( + "open_spiel (pyspiel) is required. Install with `pip install open_spiel`." + ) from e + + +@dataclass +class ConnectFourConfig: + game_string: str = "connect_four" + # If True, the env auto-plays the opponent (player 1) using a trivial policy + # whenever it becomes their turn (keeps a single-agent loop simple). 
@dataclass
class ConnectFourConfig:
    """Runtime options for :class:`ConnectFourEnvironment`."""

    # Game string handed to ``pyspiel.load_game``.
    game_string: str = "connect_four"
    # If True, the env auto-plays the opponent (player 1) using a trivial policy
    # whenever it becomes their turn (keeps a single-agent loop simple).
    autoplay_opponent: bool = False
    # Opponent policy: "random" | "lowest" | "highest"
    # (any other value is treated like "random").
    opponent_policy: str = "random"


class ConnectFourEnvironment:
    """OpenSpiel-backed Connect Four with OpenEnv-compatible semantics.

    The agent is player 0 and the opponent is player 1.  ``reset`` and
    ``step`` both return ``(observation, state)`` as two plain dicts:

    * observation: ``board`` (nested lists, ``0`` empty / ``+1`` agent /
      ``-1`` opponent), ``legal_actions``, ``current_player``, ``last_move``,
      ``done``, ``reward`` and ``info``.
    * state: ``rows``, ``cols``, ``move_count`` and ``episode_id``.

    ``move_count`` counts only moves applied through :meth:`step`; moves made
    by the auto-played opponent are not included.
    """

    ROWS = 6
    COLS = 7

    # Per-instance RNG for the "random" opponent.  It stays ``None`` until
    # ``reset(seed=...)`` is called with a seed; until then the process-wide
    # ``np.random`` functions are used.
    _rng: Optional[np.random.RandomState] = None

    def __init__(self, config: Optional[ConnectFourConfig] = None):
        """Load the configured game and start from its initial state."""
        self.config = config or ConnectFourConfig()
        self._game = pyspiel.load_game(self.config.game_string)
        self._state = self._game.new_initial_state()

        # Agent = player 0; opponent = player 1
        self._agent_player: int = 0
        self._move_count: int = 0
        self._episode_id: str = ""
        # cache of reconstructed grid (-1 empty, {0,1} owners)
        self._grid_cache: Optional[np.ndarray] = None
        self._rng = None

    # ----------------------------- API -----------------------------

    def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Start a new episode and return ``(observation, state)``.

        Args:
            seed: Optional seed for the random opponent policy.  When given,
                a private ``RandomState(seed)`` is installed on this
                environment; the process-wide NumPy RNG is left untouched.
                When omitted, the previously installed generator (if any)
                keeps running.
        """
        if seed is not None:
            # RandomState(seed) produces the same stream that
            # np.random.seed(seed) would, without mutating global state that
            # other code in the process may rely on.
            self._rng = np.random.RandomState(seed)
        self._state = self._game.new_initial_state()
        self._move_count = 0
        self._episode_id = self._new_episode_id()
        self._grid_cache = None
        return self._snapshot(done=False, reward=0.0)

    def step(self, column: int) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Apply agent move (column 0..6). Optionally autoplay opponent move.

        Returns:
            ``(observation, state)``.  An action that is not currently legal
            does not change the game; it yields ``done=True``,
            ``reward=-1.0`` and ``info["error"] == "illegal_action"``.

        Raises:
            AssertionError: if ``column`` is outside ``0..COLS-1``.
        """
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is kept for callers.
        if not 0 <= column < self.COLS:
            raise AssertionError(f"column out of range: {column}")

        self._maybe_autoplay_until_agent_turn()

        # Map to OpenSpiel action; legality guard
        act = self._column_to_action(column)
        if act not in self._state.legal_actions():
            info = {"error": "illegal_action", "legal_columns": self.legal_actions()}
            return self._snapshot(done=True, reward=-1.0, info=info)

        self._state.apply_action(act)
        self._move_count += 1
        self._invalidate_grid_cache()

        # The opponent only replies if the agent's move did not end the game.
        if not self._state.is_terminal() and self.config.autoplay_opponent:
            self._autoplay_opponent_once()

        if self._state.is_terminal():
            return self._snapshot(done=True, reward=self._terminal_reward_for_agent())
        return self._snapshot(done=False, reward=0.0)

    def close(self) -> None:
        """Discard the current game position and the cached grid.

        Nothing external is released here; the counters returned by
        ``_build_state`` are left as they were.
        """
        self._state = self._game.new_initial_state()
        self._grid_cache = None

    # --------------------------- helpers ---------------------------

    def legal_actions(self) -> List[int]:
        """Return the currently playable columns, sorted ascending."""
        return sorted({self._action_to_column(a) for a in self._state.legal_actions()})

    def current_player(self) -> int:
        """Return ``+1`` if the agent is to move, ``-1`` otherwise."""
        return 1 if self._state.current_player() == self._agent_player else -1

    def board_agent_view(self) -> np.ndarray:
        """Return 6x7 board: 0 empty, +1 agent discs, -1 opponent discs."""
        grid = self._reconstruct_grid_from_history()
        board = np.zeros_like(grid, dtype=int)  # empty cells stay 0
        board[grid == self._agent_player] = 1
        board[(grid != -1) & (grid != self._agent_player)] = -1
        return board

    def _reconstruct_grid_from_history(self) -> np.ndarray:
        """Rebuild grid (-1 empty, 0/1 owners) from action history.

        The result is cached until ``_invalidate_grid_cache`` is called, and
        the cached array itself is returned (callers must not mutate it).
        """
        if self._grid_cache is not None:
            return self._grid_cache
        grid = np.full((self.ROWS, self.COLS), -1, dtype=int)
        # Assumes history() lists every played action in order, with player 0
        # moving first and the players strictly alternating.
        player = 0
        for act in self._state.history():
            col = self._action_to_column(act)
            row = self._lowest_empty_row(grid, col)
            if row is not None:
                grid[row, col] = player
            player = 1 - player
        self._grid_cache = grid
        return grid

    @staticmethod
    def _lowest_empty_row(grid: np.ndarray, col: int) -> Optional[int]:
        """Return the largest row index still empty in ``col``, or ``None``.

        Discs therefore stack from the last row of the array towards row 0.
        """
        for r in range(grid.shape[0] - 1, -1, -1):
            if grid[r, col] == -1:
                return r
        return None

    def _invalidate_grid_cache(self) -> None:
        """Drop the cached grid so it is rebuilt on the next board read."""
        self._grid_cache = None

    # ----- action mapping -----

    def _column_to_action(self, col: int) -> int:
        """Map a column to the matching legal action, or ``col`` if none."""
        # OpenSpiel uses 0..6 column IDs as actions
        # still verify against legal action list in case of variant configs
        legal = self._state.legal_actions()
        return next((a for a in legal if self._action_to_column(a) == col), col)

    @staticmethod
    def _action_to_column(action: int) -> int:
        """Map an action id to its column (identity, coerced to ``int``)."""
        return int(action)

    # ----- opponent autoplay -----

    def _maybe_autoplay_until_agent_turn(self) -> None:
        """With autoplay enabled, play opponent moves until the agent is up."""
        if not self.config.autoplay_opponent:
            return
        while self._state.current_player() != self._agent_player and not self._state.is_terminal():
            self._autoplay_opponent_once()

    def _autoplay_opponent_once(self) -> None:
        """Play a single opponent move chosen by ``config.opponent_policy``.

        Does nothing when it is the agent's turn, the game is over, or there
        is no legal action.  ``_move_count`` is not incremented here.
        """
        if self._state.current_player() == self._agent_player or self._state.is_terminal():
            return
        legal = self._state.legal_actions()
        if not legal:
            return
        cols = [self._action_to_column(a) for a in legal]
        policy = self.config.opponent_policy
        if policy == "lowest":
            chosen_col = min(cols)
        elif policy == "highest":
            chosen_col = max(cols)
        else:
            # Private generator once reset(seed=...) installed one, otherwise
            # the module-level NumPy functions.
            rng = self._rng if self._rng is not None else np.random
            chosen_col = int(rng.choice(cols))
        self._state.apply_action(self._column_to_action(chosen_col))
        self._invalidate_grid_cache()

    # ----- rewards -----

    def _terminal_reward_for_agent(self) -> float:
        """Return ``+1.0`` / ``-1.0`` / ``0.0`` from the agent's final return.

        Non-terminal states always give ``0.0``.
        """
        if not self._state.is_terminal():
            return 0.0
        returns = self._state.returns()
        val = float(returns[self._agent_player])  # >0 win, <0 loss, 0 draw
        if val > 0:
            return 1.0
        if val < 0:
            return -1.0
        return 0.0

    # ----- payloads -----

    def _snapshot(
        self,
        *,
        done: bool,
        reward: float,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ``(observation, state)`` pair returned by the API.

        ``info`` defaults to ``{"engine": "open_spiel"}`` when not supplied.
        """
        if info is None:
            info = {"engine": "open_spiel"}
        obs = self._build_observation(done=done, reward=reward, info=info)
        return obs, self._build_state()

    def _build_observation(self, done: bool, reward: float, info: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the observation dict for the current position.

        When ``done`` is true, ``legal_actions`` is empty and
        ``current_player`` is reported as ``1``.
        """
        board = self.board_agent_view()
        return {
            "board": board.tolist(),
            "legal_actions": [] if done else self.legal_actions(),
            "current_player": self.current_player() if not done else 1,
            "last_move": self._last_move_column(),
            "done": bool(done),
            "reward": float(reward),
            "info": dict(info or {}),  # copy so callers cannot alias our dict
        }

    def _build_state(self) -> Dict[str, Any]:
        """Assemble the episode metadata dict."""
        return {
            "rows": self.ROWS,
            "cols": self.COLS,
            "move_count": self._move_count,
            "episode_id": self._episode_id,
        }

    def _last_move_column(self) -> Optional[int]:
        """Return the column of the last entry in the history, if any."""
        hist = self._state.history()
        if not hist:
            return None
        return self._action_to_column(hist[-1])

    @staticmethod
    def _new_episode_id() -> str:
        """Return a fresh random UUID4 string."""
        import uuid

        return str(uuid.uuid4())