Skip to content

Commit bcecb3b

Browse files
committed
implement basic textarena wrapper server
1 parent 0807289 commit bcecb3b

File tree

4 files changed

+315
-0
lines changed

4 files changed

+315
-0
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Use the shared OpenEnv base image (Python 3.11)
# BASE_IMAGE is a build argument, so a different base can be supplied with
# `--build-arg BASE_IMAGE=<image>`.
ARG BASE_IMAGE=openenv-base:latest
FROM ${BASE_IMAGE}

# Install system libraries required by TextArena (cv2 needs libGL, glib)
# The apt package lists are deleted in the same layer so they do not end up
# in the final image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Install TextArena and Python dependencies
# Versions are pinned; --no-cache-dir keeps pip's download cache out of the layer.
RUN pip install --no-cache-dir \
    textarena==0.6.1 \
    nltk==3.9.2

# Copy OpenEnv core and TextArena environment sources
COPY src/core/ /app/src/core/
COPY src/envs/textarena_env/ /app/src/envs/textarena_env/

# Optional: health check to ensure server responsiveness
# NOTE(review): this relies on `curl` being present in the image; it is not
# installed by this Dockerfile, so presumably the base image provides it - confirm.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the TextArena FastAPI server
# NOTE(review): the module path `envs.textarena_env.server.app` only resolves if
# /app/src is importable (working directory or PYTHONPATH); neither is set here,
# so presumably the base image configures it - confirm.
CMD ["uvicorn", "envs.textarena_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
32+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Server components for the generic TextArena environment."""
8+
9+
from .environment import TextArenaEnvironment
10+
11+
__all__ = ["TextArenaEnvironment"]
12+
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""FastAPI application entrypoint for the TextArena environment."""
8+
9+
from __future__ import annotations
10+
11+
import os
12+
13+
from core.env_server.http_server import create_app
14+
15+
from ..models import TextArenaAction, TextArenaObservation
16+
from .environment import TextArenaEnvironment
17+
18+
19+
def _parse_env_kwargs(prefix: str = "TEXTARENA_KW_") -> dict[str, str]:
    """Gather extra environment kwargs from prefixed process variables.

    Every variable in ``os.environ`` whose name starts with ``prefix`` is
    included. The prefix is stripped from the name and the remainder is
    lower-cased to form the key; the value is passed through unchanged as a
    string.

    Args:
        prefix: Name prefix that marks a variable as an environment kwarg.

    Returns:
        Mapping of lower-cased, prefix-stripped names to their raw string
        values. Empty when no variable matches.
    """
    prefix_len = len(prefix)
    return {
        name[prefix_len:].lower(): raw_value
        for name, raw_value in os.environ.items()
        if name.startswith(prefix)
    }
28+
29+
30+
# ---------------------------------------------------------------------------
# Server configuration, read from the process environment at import time.
#
#   TEXTARENA_ENV_ID         TextArena game id (default "Wordle-v0")
#   TEXTARENA_NUM_PLAYERS    number of players passed to reset (default 1)
#   TEXTARENA_MAX_TURNS      optional integer; unset or empty means no value
#   TEXTARENA_DOWNLOAD_NLTK  "1"/"true" (any case) enables the NLTK downloads
#   TEXTARENA_KW_<NAME>      forwarded to the game as keyword ``<name>``
# ---------------------------------------------------------------------------
env_id = os.getenv("TEXTARENA_ENV_ID", "Wordle-v0")
num_players = int(os.getenv("TEXTARENA_NUM_PLAYERS", "1"))

# Treat an empty value (e.g. ``TEXTARENA_MAX_TURNS=`` in a compose file) the
# same as an unset variable instead of crashing in ``int("")``.
max_turns_env = os.getenv("TEXTARENA_MAX_TURNS")
max_turns = int(max_turns_env) if max_turns_env and max_turns_env.strip() else None

# Compare case-insensitively and ignore surrounding whitespace so values such
# as "TRUE" are not silently interpreted as false.
download_nltk = os.getenv("TEXTARENA_DOWNLOAD_NLTK", "1").strip().lower() in {"1", "true"}

# Values collected here are always strings; they are passed through as-is.
extra_kwargs = _parse_env_kwargs()

# A single environment instance is created at import time and handed to the
# HTTP application factory.
environment = TextArenaEnvironment(
    env_id=env_id,
    num_players=num_players,
    max_turns=max_turns,
    download_nltk=download_nltk,
    env_kwargs=extra_kwargs,
)

app = create_app(environment, TextArenaAction, TextArenaObservation, env_name="textarena_env")


if __name__ == "__main__":
    # Allow running the module directly without the uvicorn CLI.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
53+
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Server implementation for the generic TextArena environment."""
8+
9+
from __future__ import annotations
10+
11+
import sys
12+
from typing import Any, Dict, Iterable, List, Optional
13+
from uuid import uuid4
14+
15+
import nltk
16+
17+
from core.env_server.interfaces import Environment
18+
19+
from ..models import TextArenaAction, TextArenaMessage, TextArenaObservation, TextArenaState
20+
21+
22+
# Module-level cache for the lazy ``textarena`` import: the imported module on
# success, or the exception from the first failed attempt so that later calls
# re-raise it instead of retrying.
_TEXTARENA_MODULE: Optional[Any] = None
_TEXTARENA_IMPORT_ERROR: Optional[Exception] = None


def _import_textarena() -> Any:
    """Return the ``textarena`` module, importing it on first use.

    The outcome of the first call is remembered: a successful import is
    returned from the cache afterwards, and a failure is re-raised on every
    subsequent call.

    Returns:
        The imported ``textarena`` module.

    Raises:
        RuntimeError: If the interpreter is older than Python 3.10.
        Exception: Whatever the ``textarena`` import itself raised.
    """
    global _TEXTARENA_MODULE, _TEXTARENA_IMPORT_ERROR

    cached_module = _TEXTARENA_MODULE
    if cached_module is not None:
        return cached_module

    cached_error = _TEXTARENA_IMPORT_ERROR
    if cached_error is not None:
        raise cached_error

    if sys.version_info < (3, 10):
        interpreter = f"{sys.version_info.major}.{sys.version_info.minor}"
        _TEXTARENA_IMPORT_ERROR = RuntimeError(
            "TextArena environments require Python 3.10 or newer; "
            f"current interpreter is {interpreter}"
        )
        raise _TEXTARENA_IMPORT_ERROR

    try:
        import textarena as ta  # type: ignore[import]
    except Exception as exc:  # pragma: no cover - surfaced to caller
        # Remember the failure so the next call fails the same way.
        _TEXTARENA_IMPORT_ERROR = exc
        raise
    else:
        _TEXTARENA_MODULE = ta
        return ta
52+
53+
54+
class TextArenaEnvironment(Environment):
    """Wrap any TextArena game behind the OpenEnv ``Environment`` API.

    One TextArena environment is created with ``textarena.make`` when the
    wrapper is constructed and is reused for every episode: ``reset`` starts a
    new episode on it and ``step`` forwards the action's message to it.
    """

    def __init__(
        self,
        env_id: str = "Wordle-v0",
        *,
        num_players: int = 1,
        max_turns: Optional[int] = None,
        download_nltk: bool = True,
        env_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create the wrapped TextArena environment.

        Args:
            env_id: TextArena environment id passed to ``textarena.make``.
            num_players: Number of players passed to the game on ``reset``.
            max_turns: Recorded on the wrapper and its state.
                NOTE(review): nothing in this class enforces the limit.
            download_nltk: When true, download the NLTK ``words`` and
                ``averaged_perceptron_tagger_eng`` resources before the game
                is created.
            env_kwargs: Extra keyword arguments forwarded to
                ``textarena.make``.

        Raises:
            Exception: Whatever ``_import_textarena`` raises when TextArena
                cannot be imported.
        """
        super().__init__()

        ta = _import_textarena()

        if download_nltk:
            nltk.download("words", quiet=True)
            nltk.download("averaged_perceptron_tagger_eng", quiet=True)

        self.env_id = env_id
        self.num_players = num_players
        self.max_turns = max_turns
        self._env_kwargs = env_kwargs or {}

        self._ta_env = ta.make(env_id=env_id, **self._env_kwargs)

        self._state = TextArenaState(
            env_id=env_id,
            num_players=num_players,
            max_turns=max_turns,
        )

    # ------------------------------------------------------------------
    # Environment interface
    # ------------------------------------------------------------------
    def reset(self) -> TextArenaObservation:
        """Start a new episode and return its first observation.

        Resets the wrapped game, assigns a fresh episode id, clears the
        counters and last reward/info on the tracked state, and returns an
        observation with ``reward`` 0.0 and ``done`` False.
        """
        self._ta_env.reset(num_players=self.num_players)

        self._state.episode_id = str(uuid4())
        self._state.step_count = 0
        self._state.turn = 0
        self._state.last_reward = 0.0
        self._state.last_info = {}
        self._state.raw_state = self._snapshot_state()

        observation = self._build_observation()
        observation.reward = 0.0
        observation.done = False

        return observation

    def step(self, action: TextArenaAction) -> TextArenaObservation:  # type: ignore[override]
        """Apply ``action`` to the wrapped game and return the next observation.

        Args:
            action: Action whose ``message`` is forwarded to the game.

        Returns:
            The observation built after the step, with ``done`` taken from
            the game's step result and ``reward`` from ``_extract_reward``.

        Raises:
            TypeError: If ``action`` is not a ``TextArenaAction``.
        """
        if not isinstance(action, TextArenaAction):
            raise TypeError(f"Expected TextArenaAction, received {type(action)!r}")

        done, info = self._ta_env.step(action.message)

        self._state.step_count += 1
        # Prefer the game's own turn counter; fall back to counting locally.
        self._state.turn = getattr(self._ta_env.state, "turn", self._state.turn + 1)
        self._state.last_info = info or {}

        observation = self._build_observation()
        observation.done = done

        reward = self._extract_reward()
        observation.reward = reward
        self._state.last_reward = reward
        self._state.raw_state = self._snapshot_state()

        return observation

    @property
    def state(self) -> TextArenaState:
        """Return the tracked episode state object."""
        return self._state

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _build_observation(self) -> TextArenaObservation:
        """Build an observation from the game's current observation.

        The prompt is the newline-joined content of all messages whose
        category is ``"PROMPT"``; when there are none, every message is used.
        """
        player_id, messages = self._ta_env.get_observation()

        ta_messages = self._convert_messages(messages)
        prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"]
        if not prompt_lines:
            # Fallback to most recent message history for prompt
            prompt_lines = [msg.content for msg in ta_messages]

        info: Dict[str, Any] = {}
        # ``step_info`` can be missing or ``None``; ``dict.update(None)`` raises
        # TypeError, so fall back to an empty mapping (as ``_legal_players`` does).
        info.update(getattr(self._ta_env.state, "step_info", None) or {})

        observation = TextArenaObservation(
            prompt="\n".join(prompt_lines).strip(),
            messages=ta_messages,
            current_player_id=player_id,
            legal_players=self._legal_players(),
            info=info,
            metadata={
                "env_id": self.env_id,
                "turn": getattr(self._ta_env.state, "turn", 0),
                "raw_messages": [
                    {
                        "sender_id": msg.sender_id,
                        "content": msg.content,
                        "category": msg.category,
                    }
                    for msg in ta_messages
                ],
            },
        )

        return observation

    def _legal_players(self) -> List[int]:
        """Return the sorted non-negative integer keys of the game's ``role_mapping``."""
        role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {}
        players = [pid for pid in role_mapping.keys() if isinstance(pid, int) and pid >= 0]
        return sorted(players)

    def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]:
        """Convert raw game messages into ``TextArenaMessage`` objects.

        Accepted entry shapes:
            * 3-tuple ``(sender, content, category)``
            * 2-tuple ``(sender, content)``; category defaults to ``"MESSAGE"``
            * anything else; stringified, sender ``-1``, category ``"MESSAGE"``

        A non-numeric sender becomes ``-1``. A category with a ``name``
        attribute (such as an enum member) is stored by that name; otherwise
        its string form is used.
        """
        converted: List[TextArenaMessage] = []
        for entry in messages:
            if isinstance(entry, tuple) and len(entry) == 3:
                sender, content, category = entry
            elif isinstance(entry, tuple) and len(entry) == 2:
                sender, content = entry
                category = "MESSAGE"
            else:
                sender, content, category = -1, str(entry), "MESSAGE"

            category_name = getattr(category, "name", str(category))
            converted.append(
                TextArenaMessage(
                    sender_id=int(sender) if isinstance(sender, (int, float)) else -1,
                    content=str(content),
                    category=category_name,
                )
            )

        return converted

    def _extract_reward(self) -> float:
        """Return the reward to report for the latest step.

        When the game state exposes ``rewards`` as a dict, the entry for the
        state's ``current_player_id`` is used if present, otherwise the entry
        for player 0. In every other case the reward is 0.0.
        """
        rewards = getattr(self._ta_env.state, "rewards", None)
        if isinstance(rewards, dict):
            # Use current player reward if available, otherwise default to player 0.
            player_id = getattr(self._ta_env.state, "current_player_id", 0)
            if player_id in rewards:
                return float(rewards[player_id])
            if 0 in rewards:
                return float(rewards[0])
        return 0.0

    def _snapshot_state(self) -> Dict[str, Any]:
        """Collect selected attributes of the game state into a plain dict.

        Missing attributes fall back to the defaults shown below. Only
        ``logs`` is copied (into a new list); the other values are stored by
        reference to the objects held by the game state.
        """
        state = self._ta_env.state
        snapshot: Dict[str, Any] = {
            "turn": getattr(state, "turn", 0),
            "game_state": getattr(state, "game_state", {}),
            "logs": list(getattr(state, "logs", [])),
            "rewards": getattr(state, "rewards", None),
            "done": getattr(state, "done", False),
            "role_mapping": getattr(state, "role_mapping", {}),
            "game_info": getattr(state, "game_info", {}),
            "step_info": getattr(state, "step_info", {}),
        }
        return snapshot
218+

0 commit comments

Comments
 (0)