Skip to content

Commit 952173e

Browse files
committed
implement basic text arena client
1 parent bcecb3b commit 952173e

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed

src/envs/textarena_env/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""TextArena environment integration for OpenEnv."""
8+
9+
from .client import TextArenaEnv
10+
from .models import (
11+
TextArenaAction,
12+
TextArenaMessage,
13+
TextArenaObservation,
14+
TextArenaState,
15+
)
16+
17+
__all__ = [
18+
"TextArenaEnv",
19+
"TextArenaAction",
20+
"TextArenaObservation",
21+
"TextArenaState",
22+
"TextArenaMessage",
23+
]
24+

src/envs/textarena_env/client.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""HTTP client for the generic TextArena environment."""
8+
9+
from __future__ import annotations
10+
11+
from typing import Any, Dict, TYPE_CHECKING
12+
13+
from core.client_types import StepResult
14+
from core.http_env_client import HTTPEnvClient
15+
16+
from .models import (
17+
TextArenaAction,
18+
TextArenaMessage,
19+
TextArenaObservation,
20+
TextArenaState,
21+
)
22+
23+
if TYPE_CHECKING:
24+
from core.containers.runtime import ContainerProvider
25+
26+
27+
class TextArenaEnv(HTTPEnvClient[TextArenaAction, TextArenaObservation]):
28+
"""HTTP client for the TextArena environment server."""
29+
30+
def _step_payload(self, action: TextArenaAction) -> Dict[str, Any]:
31+
return {"message": action.message}
32+
33+
def _parse_result(
34+
self, payload: Dict[str, Any]
35+
) -> StepResult[TextArenaObservation]:
36+
obs_data = payload.get("observation", {})
37+
messages_payload = obs_data.get("messages", [])
38+
messages = [
39+
TextArenaMessage(
40+
sender_id=item.get("sender_id", -1),
41+
content=item.get("content", ""),
42+
category=item.get("category", "MESSAGE"),
43+
)
44+
for item in messages_payload
45+
if isinstance(item, dict)
46+
]
47+
48+
observation = TextArenaObservation(
49+
prompt=obs_data.get("prompt", ""),
50+
messages=messages,
51+
current_player_id=obs_data.get("current_player_id", 0),
52+
legal_players=obs_data.get("legal_players", []),
53+
info=obs_data.get("info", {}),
54+
reward=payload.get("reward"),
55+
done=payload.get("done", False),
56+
metadata=obs_data.get("metadata", {}),
57+
)
58+
return StepResult(
59+
observation=observation,
60+
reward=payload.get("reward"),
61+
done=payload.get("done", False),
62+
)
63+
64+
def _parse_state(self, payload: Dict[str, Any]) -> TextArenaState:
65+
return TextArenaState(
66+
episode_id=payload.get("episode_id"),
67+
step_count=payload.get("step_count", 0),
68+
env_id=payload.get("env_id", "unknown"),
69+
num_players=payload.get("num_players", 1),
70+
max_turns=payload.get("max_turns"),
71+
turn=payload.get("turn", 0),
72+
last_reward=payload.get("last_reward", 0.0),
73+
last_info=payload.get("last_info", {}),
74+
raw_state=payload.get("raw_state", {}),
75+
)
76+

src/envs/textarena_env/models.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Common data models for the TextArena environment wrapper."""
8+
9+
from __future__ import annotations
10+
11+
from dataclasses import dataclass, field
12+
from typing import Any, Dict, List, Optional
13+
14+
from core.env_server.types import Action, Observation, State
15+
16+
17+
@dataclass
18+
class TextArenaMessage:
19+
"""Single message observed by a player."""
20+
21+
sender_id: int
22+
content: str
23+
category: str
24+
25+
26+
@dataclass(kw_only=True)
27+
class TextArenaAction(Action):
28+
"""Action issued by the agent for TextArena games."""
29+
30+
message: str
31+
32+
33+
@dataclass(kw_only=True)
34+
class TextArenaObservation(Observation):
35+
"""Observation returned from any TextArena game."""
36+
37+
prompt: str
38+
messages: List[TextArenaMessage] = field(default_factory=list)
39+
current_player_id: int = 0
40+
legal_players: List[int] = field(default_factory=list)
41+
info: Dict[str, Any] = field(default_factory=dict)
42+
43+
44+
@dataclass(kw_only=True)
45+
class TextArenaState(State):
46+
"""Structured state snapshot for the server."""
47+
48+
env_id: str
49+
num_players: int
50+
max_turns: Optional[int] = None
51+
turn: int = 0
52+
last_reward: float = 0.0
53+
last_info: Dict[str, Any] = field(default_factory=dict)
54+
raw_state: Dict[str, Any] = field(default_factory=dict)
55+

0 commit comments

Comments
 (0)