1616
1717from core .env_server .interfaces import Environment
1818
19- from ..models import TextArenaAction , TextArenaMessage , TextArenaObservation , TextArenaState
19+ from ..models import (
20+ TextArenaAction ,
21+ TextArenaMessage ,
22+ TextArenaObservation ,
23+ TextArenaState ,
24+ )
2025from ..rewards import RewardProvider , build_reward_providers
2126
2227
@@ -92,6 +97,18 @@ def __init__(
9297 # Environment interface
9398 # ------------------------------------------------------------------
9499 def reset (self ) -> TextArenaObservation :
100+ # TextArena observation wrappers (LLMObservationWrapper, etc.) accumulate
101+ # observations in self.full_observations across resets. Since we can't modify TextArena,
102+ # we need to manually clear this state to prevent history accumulation.
103+ env = self ._ta_env
104+ while hasattr (env , "env" ):
105+ if hasattr (env , "full_observations" ):
106+ env .full_observations = {}
107+ env = env .env
108+ # Also check the final unwrapped env
109+ if hasattr (env , "full_observations" ):
110+ env .full_observations = {}
111+
95112 self ._ta_env .reset (num_players = self .num_players )
96113
97114 for provider in self ._reward_providers :
@@ -128,13 +145,18 @@ def step(self, action: TextArenaAction) -> TextArenaObservation: # type: ignore
128145 observation .reward = reward
129146 self ._state .last_reward = reward
130147
131- reward_signals = self ._compute_reward_signals (action = action , observation = observation )
148+ reward_signals = self ._compute_reward_signals (
149+ action = action , observation = observation
150+ )
132151 if reward_signals :
133152 observation .info .setdefault ("reward_signals" , {}).update (reward_signals )
134153 observation .metadata .setdefault ("reward_signals" , {}).update (reward_signals )
135154 self ._last_reward_signals = reward_signals
136155 if reward_signals :
137- self ._state .last_info = {** (self ._state .last_info or {}), "reward_signals" : reward_signals }
156+ self ._state .last_info = {
157+ ** (self ._state .last_info or {}),
158+ "reward_signals" : reward_signals ,
159+ }
138160 self ._state .raw_state = self ._snapshot_state ()
139161
140162 return observation
@@ -150,16 +172,30 @@ def _build_observation(self) -> TextArenaObservation:
150172 player_id , messages = self ._ta_env .get_observation ()
151173
152174 ta_messages = self ._convert_messages (messages )
175+
176+ # Extract prompt from the appropriate messages.
177+ # TextArena PROMPT type messages contain the game instructions added during reset.
178+ # As a fallback for environments that don't use typed messages, use only the first
179+ # message if we're at turn 0 (fresh reset).
153180 prompt_lines = [msg .content for msg in ta_messages if msg .category == "PROMPT" ]
181+
154182 if not prompt_lines :
155- # Fallback to most recent message history for prompt
156- prompt_lines = [msg .content for msg in ta_messages ]
183+ # Fallback: use the first message only if at turn 0 (just after reset)
184+ # DO NOT use all messages as this causes history accumulation
185+ current_turn = getattr (self ._ta_env .state , "turn" , 0 )
186+ if current_turn == 0 and ta_messages :
187+ prompt_lines = [ta_messages [0 ].content ]
188+ else :
189+ # Use env_id as final fallback to avoid including game history
190+ prompt_lines = [self .env_id ]
191+
192+ prompt = "\n " .join (prompt_lines ).strip ()
157193
158194 info : Dict [str , Any ] = {}
159195 info .update (getattr (self ._ta_env .state , "step_info" , {}))
160196
161197 observation = TextArenaObservation (
162- prompt = " \n " . join ( prompt_lines ). strip () ,
198+ prompt = prompt ,
163199 messages = ta_messages ,
164200 current_player_id = player_id ,
165201 legal_players = self ._legal_players (),
@@ -182,29 +218,31 @@ def _build_observation(self) -> TextArenaObservation:
182218
def _legal_players(self) -> List[int]:
    """Return the sorted, non-negative integer player ids for the current game.

    The ids are the keys of ``role_mapping`` on the wrapped TextArena
    environment's state. A missing or falsy ``role_mapping`` yields an
    empty list. Keys that are not plain integers, or that are negative,
    are skipped.
    """
    role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {}
    # bool is a subclass of int, so a bare isinstance(pid, int) check would
    # let True/False through as player ids; exclude them explicitly.
    return sorted(
        pid
        for pid in role_mapping
        if isinstance(pid, int) and not isinstance(pid, bool) and pid >= 0
    )
187225
188226 def _convert_messages (self , messages : Iterable [Any ]) -> List [TextArenaMessage ]:
189227 converted : List [TextArenaMessage ] = []
190- buffered_content : List [str ] = []
191228 buffered_sender : int | None = None
192229 buffered_category : str | None = None
193- last_char_was_newline = False
230+ buffered_content : List [ str ] = []
194231
195232 def flush_buffer () -> None :
196233 nonlocal buffered_content , buffered_sender , buffered_category
197- if buffered_content :
198- converted . append (
199- TextArenaMessage (
200- sender_id = buffered_sender if buffered_sender is not None else - 1 ,
201- content = "" . join ( buffered_content ) ,
202- category = buffered_category or "MESSAGE" ,
203- )
234+ if not buffered_content :
235+ return
236+ converted . append (
237+ TextArenaMessage (
238+ sender_id = buffered_sender if buffered_sender is not None else - 1 ,
239+ content = "" . join ( buffered_content ) ,
240+ category = buffered_category or "MESSAGE" ,
204241 )
242+ )
205243 buffered_content = []
206- buffered_sender = None
207244 buffered_category = None
245+ buffered_sender = None
208246
209247 for entry in messages :
210248 if isinstance (entry , tuple ) and len (entry ) == 3 :
@@ -219,29 +257,17 @@ def flush_buffer() -> None:
219257 sender_id = int (sender ) if isinstance (sender , (int , float )) else - 1
220258 text = str (content )
221259
222- if text == "\n " :
223- flush_buffer ()
224- if last_char_was_newline :
225- converted .append (
226- TextArenaMessage (
227- sender_id = sender_id ,
228- content = "" ,
229- category = category_name ,
230- )
231- )
232- last_char_was_newline = True
233- continue
234-
235- if buffered_sender is None or buffered_category is None :
236- buffered_sender = sender_id
237- buffered_category = category_name
238- elif buffered_sender != sender_id or buffered_category != category_name :
260+ if (
261+ buffered_content
262+ and buffered_category == category_name
263+ and buffered_sender == sender_id
264+ ):
265+ buffered_content .append (text )
266+ else :
239267 flush_buffer ()
240268 buffered_sender = sender_id
241269 buffered_category = category_name
242-
243- buffered_content .append (text )
244- last_char_was_newline = False
270+ buffered_content = [text ]
245271
246272 flush_buffer ()
247273
0 commit comments