|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +"""Server implementation for the generic TextArena environment.""" |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import sys |
| 12 | +from typing import Any, Dict, Iterable, List, Optional |
| 13 | +from uuid import uuid4 |
| 14 | + |
| 15 | +import nltk |
| 16 | + |
| 17 | +from core.env_server.interfaces import Environment |
| 18 | + |
| 19 | +from ..models import TextArenaAction, TextArenaMessage, TextArenaObservation, TextArenaState |
| 20 | + |
| 21 | + |
| 22 | +_TEXTARENA_MODULE: Any | None = None |
| 23 | +_TEXTARENA_IMPORT_ERROR: Exception | None = None |
| 24 | + |
| 25 | + |
| 26 | +def _import_textarena() -> Any: |
| 27 | + """Import ``textarena`` lazily and cache the module reference.""" |
| 28 | + |
| 29 | + global _TEXTARENA_MODULE, _TEXTARENA_IMPORT_ERROR |
| 30 | + |
| 31 | + if _TEXTARENA_MODULE is not None: |
| 32 | + return _TEXTARENA_MODULE |
| 33 | + |
| 34 | + if _TEXTARENA_IMPORT_ERROR is not None: |
| 35 | + raise _TEXTARENA_IMPORT_ERROR |
| 36 | + |
| 37 | + if sys.version_info < (3, 10): |
| 38 | + _TEXTARENA_IMPORT_ERROR = RuntimeError( |
| 39 | + "TextArena environments require Python 3.10 or newer; " |
| 40 | + f"current interpreter is {sys.version_info.major}.{sys.version_info.minor}" |
| 41 | + ) |
| 42 | + raise _TEXTARENA_IMPORT_ERROR |
| 43 | + |
| 44 | + try: |
| 45 | + import textarena as ta # type: ignore[import] |
| 46 | + except Exception as exc: # pragma: no cover - surfaced to caller |
| 47 | + _TEXTARENA_IMPORT_ERROR = exc |
| 48 | + raise |
| 49 | + |
| 50 | + _TEXTARENA_MODULE = ta |
| 51 | + return ta |
| 52 | + |
| 53 | + |
| 54 | +class TextArenaEnvironment(Environment): |
| 55 | + """Wrap any TextArena game behind the OpenEnv ``Environment`` API.""" |
| 56 | + |
| 57 | + def __init__( |
| 58 | + self, |
| 59 | + env_id: str = "Wordle-v0", |
| 60 | + *, |
| 61 | + num_players: int = 1, |
| 62 | + max_turns: Optional[int] = None, |
| 63 | + download_nltk: bool = True, |
| 64 | + env_kwargs: Optional[Dict[str, Any]] = None, |
| 65 | + ) -> None: |
| 66 | + super().__init__() |
| 67 | + |
| 68 | + ta = _import_textarena() |
| 69 | + |
| 70 | + if download_nltk: |
| 71 | + nltk.download("words", quiet=True) |
| 72 | + nltk.download("averaged_perceptron_tagger_eng", quiet=True) |
| 73 | + |
| 74 | + self.env_id = env_id |
| 75 | + self.num_players = num_players |
| 76 | + self.max_turns = max_turns |
| 77 | + self._env_kwargs = env_kwargs or {} |
| 78 | + |
| 79 | + self._ta_env = ta.make(env_id=env_id, **self._env_kwargs) |
| 80 | + |
| 81 | + self._state = TextArenaState( |
| 82 | + env_id=env_id, |
| 83 | + num_players=num_players, |
| 84 | + max_turns=max_turns, |
| 85 | + ) |
| 86 | + |
| 87 | + # ------------------------------------------------------------------ |
| 88 | + # Environment interface |
| 89 | + # ------------------------------------------------------------------ |
| 90 | + def reset(self) -> TextArenaObservation: |
| 91 | + self._ta_env.reset(num_players=self.num_players) |
| 92 | + |
| 93 | + self._state.episode_id = str(uuid4()) |
| 94 | + self._state.step_count = 0 |
| 95 | + self._state.turn = 0 |
| 96 | + self._state.last_reward = 0.0 |
| 97 | + self._state.last_info = {} |
| 98 | + self._state.raw_state = self._snapshot_state() |
| 99 | + |
| 100 | + observation = self._build_observation() |
| 101 | + observation.reward = 0.0 |
| 102 | + observation.done = False |
| 103 | + |
| 104 | + return observation |
| 105 | + |
| 106 | + def step(self, action: TextArenaAction) -> TextArenaObservation: # type: ignore[override] |
| 107 | + if not isinstance(action, TextArenaAction): |
| 108 | + raise TypeError(f"Expected TextArenaAction, received {type(action)!r}") |
| 109 | + |
| 110 | + done, info = self._ta_env.step(action.message) |
| 111 | + |
| 112 | + self._state.step_count += 1 |
| 113 | + self._state.turn = getattr(self._ta_env.state, "turn", self._state.turn + 1) |
| 114 | + self._state.last_info = info or {} |
| 115 | + |
| 116 | + observation = self._build_observation() |
| 117 | + observation.done = done |
| 118 | + |
| 119 | + reward = self._extract_reward() |
| 120 | + observation.reward = reward |
| 121 | + self._state.last_reward = reward |
| 122 | + self._state.raw_state = self._snapshot_state() |
| 123 | + |
| 124 | + return observation |
| 125 | + |
| 126 | + @property |
| 127 | + def state(self) -> TextArenaState: |
| 128 | + return self._state |
| 129 | + |
| 130 | + # ------------------------------------------------------------------ |
| 131 | + # Helpers |
| 132 | + # ------------------------------------------------------------------ |
| 133 | + def _build_observation(self) -> TextArenaObservation: |
| 134 | + player_id, messages = self._ta_env.get_observation() |
| 135 | + |
| 136 | + ta_messages = self._convert_messages(messages) |
| 137 | + prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"] |
| 138 | + if not prompt_lines: |
| 139 | + # Fallback to most recent message history for prompt |
| 140 | + prompt_lines = [msg.content for msg in ta_messages] |
| 141 | + |
| 142 | + info: Dict[str, Any] = {} |
| 143 | + info.update(getattr(self._ta_env.state, "step_info", {})) |
| 144 | + |
| 145 | + observation = TextArenaObservation( |
| 146 | + prompt="\n".join(prompt_lines).strip(), |
| 147 | + messages=ta_messages, |
| 148 | + current_player_id=player_id, |
| 149 | + legal_players=self._legal_players(), |
| 150 | + info=info, |
| 151 | + metadata={ |
| 152 | + "env_id": self.env_id, |
| 153 | + "turn": getattr(self._ta_env.state, "turn", 0), |
| 154 | + "raw_messages": [ |
| 155 | + { |
| 156 | + "sender_id": msg.sender_id, |
| 157 | + "content": msg.content, |
| 158 | + "category": msg.category, |
| 159 | + } |
| 160 | + for msg in ta_messages |
| 161 | + ], |
| 162 | + }, |
| 163 | + ) |
| 164 | + |
| 165 | + return observation |
| 166 | + |
| 167 | + def _legal_players(self) -> List[int]: |
| 168 | + role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {} |
| 169 | + players = [pid for pid in role_mapping.keys() if isinstance(pid, int) and pid >= 0] |
| 170 | + return sorted(players) |
| 171 | + |
| 172 | + def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]: |
| 173 | + converted: List[TextArenaMessage] = [] |
| 174 | + for entry in messages: |
| 175 | + if isinstance(entry, tuple) and len(entry) == 3: |
| 176 | + sender, content, category = entry |
| 177 | + elif isinstance(entry, tuple) and len(entry) == 2: |
| 178 | + sender, content = entry |
| 179 | + category = "MESSAGE" |
| 180 | + else: |
| 181 | + sender, content, category = -1, str(entry), "MESSAGE" |
| 182 | + |
| 183 | + category_name = getattr(category, "name", str(category)) |
| 184 | + converted.append( |
| 185 | + TextArenaMessage( |
| 186 | + sender_id=int(sender) if isinstance(sender, (int, float)) else -1, |
| 187 | + content=str(content), |
| 188 | + category=category_name, |
| 189 | + ) |
| 190 | + ) |
| 191 | + |
| 192 | + return converted |
| 193 | + |
| 194 | + def _extract_reward(self) -> float: |
| 195 | + rewards = getattr(self._ta_env.state, "rewards", None) |
| 196 | + if isinstance(rewards, dict): |
| 197 | + # Use current player reward if available, otherwise default to player 0. |
| 198 | + player_id = getattr(self._ta_env.state, "current_player_id", 0) |
| 199 | + if player_id in rewards: |
| 200 | + return float(rewards[player_id]) |
| 201 | + if 0 in rewards: |
| 202 | + return float(rewards[0]) |
| 203 | + return 0.0 |
| 204 | + |
| 205 | + def _snapshot_state(self) -> Dict[str, Any]: |
| 206 | + state = self._ta_env.state |
| 207 | + snapshot: Dict[str, Any] = { |
| 208 | + "turn": getattr(state, "turn", 0), |
| 209 | + "game_state": getattr(state, "game_state", {}), |
| 210 | + "logs": list(getattr(state, "logs", [])), |
| 211 | + "rewards": getattr(state, "rewards", None), |
| 212 | + "done": getattr(state, "done", False), |
| 213 | + "role_mapping": getattr(state, "role_mapping", {}), |
| 214 | + "game_info": getattr(state, "game_info", {}), |
| 215 | + "step_info": getattr(state, "step_info", {}), |
| 216 | + } |
| 217 | + return snapshot |
| 218 | + |
0 commit comments