Skip to content

Commit 0073606

Browse files
authored
Merge pull request #157 from kashif/textenv
[textarena] fix how consecutive messages from the same sender and category are handled
2 parents facb8e6 + f94d62c commit 0073606

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

src/envs/textarena_env/server/environment.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,25 @@ def _legal_players(self) -> List[int]:
187187

188188
def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]:
189189
converted: List[TextArenaMessage] = []
190+
buffered_sender: int | None = None
191+
buffered_category: str | None = None
192+
buffered_content: List[str] = []
193+
194+
def flush_buffer() -> None:
195+
nonlocal buffered_content, buffered_sender, buffered_category
196+
if not buffered_content:
197+
return
198+
converted.append(
199+
TextArenaMessage(
200+
sender_id=buffered_sender if buffered_sender is not None else -1,
201+
content="".join(buffered_content),
202+
category=buffered_category or "MESSAGE",
203+
)
204+
)
205+
buffered_content = []
206+
buffered_category = None
207+
buffered_sender = None
208+
190209
for entry in messages:
191210
if isinstance(entry, tuple) and len(entry) == 3:
192211
sender, content, category = entry
@@ -197,13 +216,18 @@ def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]:
197216
sender, content, category = -1, str(entry), "MESSAGE"
198217

199218
category_name = getattr(category, "name", str(category))
200-
converted.append(
201-
TextArenaMessage(
202-
sender_id=int(sender) if isinstance(sender, (int, float)) else -1,
203-
content=str(content),
204-
category=category_name,
205-
)
206-
)
219+
sender_id = int(sender) if isinstance(sender, (int, float)) else -1
220+
text = str(content)
221+
222+
if buffered_content and buffered_category == category_name and buffered_sender == sender_id:
223+
buffered_content.append(text)
224+
else:
225+
flush_buffer()
226+
buffered_sender = sender_id
227+
buffered_category = category_name
228+
buffered_content = [text]
229+
230+
flush_buffer()
207231

208232
return converted
209233

@@ -249,4 +273,3 @@ def _compute_reward_signals(
249273
for key, value in result.items():
250274
aggregated[key] = float(value)
251275
return aggregated
252-
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from envs.textarena_env.server.environment import TextArenaEnvironment
2+
from envs.textarena_env.models import TextArenaMessage
3+
4+
5+
def test_convert_messages_coalesces_consecutive_characters():
6+
env = object.__new__(TextArenaEnvironment)
7+
8+
raw_messages = [
9+
(0, "[", "PROMPT"),
10+
(0, "GAME", "PROMPT"),
11+
(0, "]", "PROMPT"),
12+
(1, "A", "MESSAGE"),
13+
(1, "B", "MESSAGE"),
14+
(2, "!", "MESSAGE"),
15+
]
16+
17+
converted = env._convert_messages(raw_messages)
18+
19+
assert converted == [
20+
TextArenaMessage(sender_id=0, content="[GAME]", category="PROMPT"),
21+
TextArenaMessage(sender_id=1, content="AB", category="MESSAGE"),
22+
TextArenaMessage(sender_id=2, content="!", category="MESSAGE"),
23+
]

0 commit comments

Comments
 (0)