Skip to content

Commit 6876aff

Browse files
committed
fix LLMObservationWrapper accumulate observations in self.full_observations across resets
1 parent 3be78d4 commit 6876aff

File tree

2 files changed

+96
-69
lines changed

2 files changed

+96
-69
lines changed

src/envs/textarena_env/server/environment.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616

1717
from core.env_server.interfaces import Environment
1818

19-
from ..models import TextArenaAction, TextArenaMessage, TextArenaObservation, TextArenaState
19+
from ..models import (
20+
TextArenaAction,
21+
TextArenaMessage,
22+
TextArenaObservation,
23+
TextArenaState,
24+
)
2025
from ..rewards import RewardProvider, build_reward_providers
2126

2227

@@ -92,6 +97,18 @@ def __init__(
9297
# Environment interface
9398
# ------------------------------------------------------------------
9499
def reset(self) -> TextArenaObservation:
100+
# TextArena observation wrappers (LLMObservationWrapper, etc.) accumulate
101+
# observations in self.full_observations across resets. Since we can't modify TextArena,
102+
# we need to manually clear this state to prevent history accumulation.
103+
env = self._ta_env
104+
while hasattr(env, "env"):
105+
if hasattr(env, "full_observations"):
106+
env.full_observations = {}
107+
env = env.env
108+
# Also check the final unwrapped env
109+
if hasattr(env, "full_observations"):
110+
env.full_observations = {}
111+
95112
self._ta_env.reset(num_players=self.num_players)
96113

97114
for provider in self._reward_providers:
@@ -128,13 +145,18 @@ def step(self, action: TextArenaAction) -> TextArenaObservation: # type: ignore
128145
observation.reward = reward
129146
self._state.last_reward = reward
130147

131-
reward_signals = self._compute_reward_signals(action=action, observation=observation)
148+
reward_signals = self._compute_reward_signals(
149+
action=action, observation=observation
150+
)
132151
if reward_signals:
133152
observation.info.setdefault("reward_signals", {}).update(reward_signals)
134153
observation.metadata.setdefault("reward_signals", {}).update(reward_signals)
135154
self._last_reward_signals = reward_signals
136155
if reward_signals:
137-
self._state.last_info = {**(self._state.last_info or {}), "reward_signals": reward_signals}
156+
self._state.last_info = {
157+
**(self._state.last_info or {}),
158+
"reward_signals": reward_signals,
159+
}
138160
self._state.raw_state = self._snapshot_state()
139161

140162
return observation
@@ -150,16 +172,30 @@ def _build_observation(self) -> TextArenaObservation:
150172
player_id, messages = self._ta_env.get_observation()
151173

152174
ta_messages = self._convert_messages(messages)
175+
176+
# Extract prompt from the appropriate messages.
177+
# TextArena PROMPT type messages contain the game instructions added during reset.
178+
# As a fallback for environments that don't use typed messages, use only the first
179+
# message if we're at turn 0 (fresh reset).
153180
prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"]
181+
154182
if not prompt_lines:
155-
# Fallback to most recent message history for prompt
156-
prompt_lines = [msg.content for msg in ta_messages]
183+
# Fallback: use the first message only if at turn 0 (just after reset)
184+
# DO NOT use all messages as this causes history accumulation
185+
current_turn = getattr(self._ta_env.state, "turn", 0)
186+
if current_turn == 0 and ta_messages:
187+
prompt_lines = [ta_messages[0].content]
188+
else:
189+
# Use env_id as final fallback to avoid including game history
190+
prompt_lines = [self.env_id]
191+
192+
prompt = "\n".join(prompt_lines).strip()
157193

158194
info: Dict[str, Any] = {}
159195
info.update(getattr(self._ta_env.state, "step_info", {}))
160196

161197
observation = TextArenaObservation(
162-
prompt="\n".join(prompt_lines).strip(),
198+
prompt=prompt,
163199
messages=ta_messages,
164200
current_player_id=player_id,
165201
legal_players=self._legal_players(),
@@ -182,29 +218,31 @@ def _build_observation(self) -> TextArenaObservation:
182218

183219
def _legal_players(self) -> List[int]:
184220
role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {}
185-
players = [pid for pid in role_mapping.keys() if isinstance(pid, int) and pid >= 0]
221+
players = [
222+
pid for pid in role_mapping.keys() if isinstance(pid, int) and pid >= 0
223+
]
186224
return sorted(players)
187225

188226
def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]:
189227
converted: List[TextArenaMessage] = []
190-
buffered_content: List[str] = []
191228
buffered_sender: int | None = None
192229
buffered_category: str | None = None
193-
last_char_was_newline = False
230+
buffered_content: List[str] = []
194231

195232
def flush_buffer() -> None:
196233
nonlocal buffered_content, buffered_sender, buffered_category
197-
if buffered_content:
198-
converted.append(
199-
TextArenaMessage(
200-
sender_id=buffered_sender if buffered_sender is not None else -1,
201-
content="".join(buffered_content),
202-
category=buffered_category or "MESSAGE",
203-
)
234+
if not buffered_content:
235+
return
236+
converted.append(
237+
TextArenaMessage(
238+
sender_id=buffered_sender if buffered_sender is not None else -1,
239+
content="".join(buffered_content),
240+
category=buffered_category or "MESSAGE",
204241
)
242+
)
205243
buffered_content = []
206-
buffered_sender = None
207244
buffered_category = None
245+
buffered_sender = None
208246

209247
for entry in messages:
210248
if isinstance(entry, tuple) and len(entry) == 3:
@@ -219,29 +257,17 @@ def flush_buffer() -> None:
219257
sender_id = int(sender) if isinstance(sender, (int, float)) else -1
220258
text = str(content)
221259

222-
if text == "\n":
223-
flush_buffer()
224-
if last_char_was_newline:
225-
converted.append(
226-
TextArenaMessage(
227-
sender_id=sender_id,
228-
content="",
229-
category=category_name,
230-
)
231-
)
232-
last_char_was_newline = True
233-
continue
234-
235-
if buffered_sender is None or buffered_category is None:
236-
buffered_sender = sender_id
237-
buffered_category = category_name
238-
elif buffered_sender != sender_id or buffered_category != category_name:
260+
if (
261+
buffered_content
262+
and buffered_category == category_name
263+
and buffered_sender == sender_id
264+
):
265+
buffered_content.append(text)
266+
else:
239267
flush_buffer()
240268
buffered_sender = sender_id
241269
buffered_category = category_name
242-
243-
buffered_content.append(text)
244-
last_char_was_newline = False
270+
buffered_content = [text]
245271

246272
flush_buffer()
247273

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from envs.textarena_env.server.environment import TextArenaEnvironment
2-
from envs.textarena_env.models import TextArenaMessage
2+
from envs.textarena_env.models import TextArenaMessage, TextArenaAction
33

44

55
def test_convert_messages_coalesces_consecutive_characters():
@@ -23,42 +23,43 @@ def test_convert_messages_coalesces_consecutive_characters():
2323
]
2424

2525

26-
def test_convert_messages_splits_on_newlines():
27-
env = object.__new__(TextArenaEnvironment)
26+
def test_wordle_reset_clears_accumulated_state():
27+
"""Test that resetting Wordle environment clears accumulated observation state.
2828
29-
raw_messages = [
30-
"[",
31-
"G",
32-
"A",
33-
"M",
34-
"E",
35-
"]",
36-
"\n",
37-
"[",
38-
"N",
39-
"E",
40-
"X",
41-
"T",
42-
"]",
43-
]
29+
This test verifies the workaround for TextArena's LLMObservationWrapper,
30+
which accumulates observations in self.full_observations across resets.
31+
"""
32+
env = TextArenaEnvironment(
33+
env_id="Wordle-v0",
34+
num_players=1,
35+
)
4436

45-
converted = env._convert_messages(raw_messages)
37+
# First episode
38+
obs1 = env.reset()
39+
prompt1_len = len(obs1.prompt)
4640

47-
assert converted == [
48-
TextArenaMessage(sender_id=-1, content="[GAME]", category="MESSAGE"),
49-
TextArenaMessage(sender_id=-1, content="[NEXT]", category="MESSAGE"),
50-
]
41+
# Make a move to accumulate some state
42+
env.step(TextArenaAction(message="[CRANE]"))
5143

44+
# Second episode - should NOT accumulate from first episode
45+
obs2 = env.reset()
46+
prompt2_len = len(obs2.prompt)
5247

53-
def test_convert_messages_preserves_blank_lines():
54-
env = object.__new__(TextArenaEnvironment)
48+
# Make another move
49+
env.step(TextArenaAction(message="[STALE]"))
5550

56-
raw_messages = ["A", "\n", "\n", "B"]
51+
# Third episode - should NOT accumulate from previous episodes
52+
obs3 = env.reset()
53+
prompt3_len = len(obs3.prompt)
5754

58-
converted = env._convert_messages(raw_messages)
55+
# All prompts should be the same length (no accumulation)
56+
assert prompt1_len == prompt2_len, (
57+
f"Episode 2 accumulated state: {prompt1_len} -> {prompt2_len}"
58+
)
59+
assert prompt2_len == prompt3_len, (
60+
f"Episode 3 accumulated state: {prompt2_len} -> {prompt3_len}"
61+
)
5962

60-
assert converted == [
61-
TextArenaMessage(sender_id=-1, content="A", category="MESSAGE"),
62-
TextArenaMessage(sender_id=-1, content="", category="MESSAGE"),
63-
TextArenaMessage(sender_id=-1, content="B", category="MESSAGE"),
64-
]
63+
# Verify the prompts are actually the same content
64+
assert obs1.prompt == obs2.prompt
65+
assert obs2.prompt == obs3.prompt

0 commit comments

Comments
 (0)