Skip to content

Commit 941da1a

Browse files
committed
add wordle specific rewards to the environment
1 parent bc8204d commit 941da1a

File tree

4 files changed

+173
-3
lines changed

4 files changed

+173
-3
lines changed

src/envs/textarena_env/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
TextArenaObservation,
1414
TextArenaState,
1515
)
16+
from .rewards import RewardProvider, build_reward_providers
1617

1718
__all__ = [
1819
"TextArenaEnv",
1920
"TextArenaAction",
2021
"TextArenaObservation",
2122
"TextArenaState",
2223
"TextArenaMessage",
24+
"RewardProvider",
25+
"build_reward_providers",
2326
]
24-

src/envs/textarena_env/rewards.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Reward provider utilities for TextArena environments."""
2+
3+
from __future__ import annotations
4+
5+
import re
6+
from typing import Dict, List, Protocol, Tuple
7+
8+
from .models import TextArenaAction, TextArenaObservation
9+
10+
11+
class RewardProvider(Protocol):
12+
"""Interface for computing auxiliary reward signals."""
13+
14+
def reset(self) -> None:
15+
"""Clear any internal state before a new episode."""
16+
17+
def compute(
18+
self, *, action: TextArenaAction, observation: TextArenaObservation
19+
) -> Dict[str, float]:
20+
"""Return a mapping of reward names to float values for the step."""
21+
22+
23+
def build_reward_providers(env_id: str) -> List[RewardProvider]:
24+
"""Instantiate reward providers appropriate for the given environment."""
25+
26+
providers: List[RewardProvider] = []
27+
if env_id == "Wordle-v0":
28+
providers.append(_WordleRewardProvider())
29+
return providers
30+
31+
32+
_WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]")
33+
34+
35+
def extract_guess(text: str) -> str:
36+
"""Normalize a Wordle guess string from arbitrary text."""
37+
38+
match = _WORDLE_GUESS_PATTERN.search(text)
39+
if match:
40+
return match.group(0).lower()
41+
42+
cleaned = re.sub(r"[^a-z]", "", text.lower())
43+
if len(cleaned) >= 5:
44+
return f"[{cleaned[:5]}]"
45+
return "[dunno]"
46+
47+
48+
def extract_wordle_feedback(observation: TextArenaObservation) -> str:
49+
"""Pull the latest feedback text from a Wordle observation."""
50+
51+
for message in reversed(observation.messages):
52+
content = message.content.strip()
53+
if "Feedback:" in content:
54+
return content.split("Feedback:", 1)[-1].strip()
55+
return ""
56+
57+
58+
def extract_feedback_counts(feedback: str) -> Tuple[int, int]:
59+
"""Return counts of green (G) and yellow (Y) markers from feedback."""
60+
61+
if not feedback:
62+
return (0, 0)
63+
64+
segments = [
65+
segment.strip() for segment in feedback.split("\n\n") if segment.strip()
66+
]
67+
if not segments:
68+
return (0, 0)
69+
70+
latest_segment = segments[-1]
71+
lines = [line.strip() for line in latest_segment.splitlines() if line.strip()]
72+
latest_line = lines[-1] if lines else latest_segment
73+
74+
green_count = latest_line.count("G")
75+
yellow_count = latest_line.count("Y")
76+
return (green_count, yellow_count)
77+
78+
79+
class _WordleRewardProvider:
80+
"""Reward provider that mirrors the GRPO Wordle heuristics."""
81+
82+
SIGNAL_MAP = {
83+
"greens": "wordle.greens",
84+
"yellows": "wordle.yellows",
85+
"repetitions": "wordle.repetitions",
86+
"correct": "wordle.correct",
87+
}
88+
89+
def __init__(self) -> None:
90+
self._guess_history: Dict[str, int] = {}
91+
92+
def reset(self) -> None:
93+
self._guess_history.clear()
94+
95+
def compute(
96+
self, *, action: TextArenaAction, observation: TextArenaObservation
97+
) -> Dict[str, float]:
98+
guess = extract_guess(action.message)
99+
feedback = extract_wordle_feedback(observation)
100+
101+
normalized_guess = guess if guess and guess != "[dunno]" else ""
102+
previous_occurrences = (
103+
self._guess_history.get(normalized_guess, 0) if normalized_guess else 0
104+
)
105+
106+
green_score = 0.0
107+
yellow_score = 0.0
108+
if feedback:
109+
green_count, yellow_count = extract_feedback_counts(feedback)
110+
green_score = green_count / 5.0
111+
yellow_score = yellow_count / 5.0
112+
113+
repetition_score = 1.0 - previous_occurrences
114+
correct_score = float(observation.reward or 0.0)
115+
116+
if normalized_guess:
117+
self._guess_history[normalized_guess] = previous_occurrences + 1
118+
119+
return {
120+
self.SIGNAL_MAP["greens"]: float(green_score),
121+
self.SIGNAL_MAP["yellows"]: float(yellow_score),
122+
self.SIGNAL_MAP["repetitions"]: float(repetition_score),
123+
self.SIGNAL_MAP["correct"]: float(correct_score),
124+
}
125+
126+
127+
__all__ = [
128+
"RewardProvider",
129+
"build_reward_providers",
130+
"extract_feedback_counts",
131+
"extract_guess",
132+
"extract_wordle_feedback",
133+
]

src/envs/textarena_env/server/environment.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from core.env_server.interfaces import Environment
1818

1919
from ..models import TextArenaAction, TextArenaMessage, TextArenaObservation, TextArenaState
20+
from ..rewards import RewardProvider, build_reward_providers
2021

2122

2223
_TEXTARENA_MODULE: Any | None = None
@@ -84,18 +85,25 @@ def __init__(
8485
max_turns=max_turns,
8586
)
8687

88+
self._reward_providers: List[RewardProvider] = build_reward_providers(env_id)
89+
self._last_reward_signals: Dict[str, float] = {}
90+
8791
# ------------------------------------------------------------------
8892
# Environment interface
8993
# ------------------------------------------------------------------
9094
def reset(self) -> TextArenaObservation:
9195
self._ta_env.reset(num_players=self.num_players)
9296

97+
for provider in self._reward_providers:
98+
provider.reset()
99+
93100
self._state.episode_id = str(uuid4())
94101
self._state.step_count = 0
95102
self._state.turn = 0
96103
self._state.last_reward = 0.0
97104
self._state.last_info = {}
98105
self._state.raw_state = self._snapshot_state()
106+
self._last_reward_signals = {}
99107

100108
observation = self._build_observation()
101109
observation.reward = 0.0
@@ -119,6 +127,14 @@ def step(self, action: TextArenaAction) -> TextArenaObservation: # type: ignore
119127
reward = self._extract_reward()
120128
observation.reward = reward
121129
self._state.last_reward = reward
130+
131+
reward_signals = self._compute_reward_signals(action=action, observation=observation)
132+
if reward_signals:
133+
observation.info.setdefault("reward_signals", {}).update(reward_signals)
134+
observation.metadata.setdefault("reward_signals", {}).update(reward_signals)
135+
self._last_reward_signals = reward_signals
136+
if reward_signals:
137+
self._state.last_info = {**(self._state.last_info or {}), "reward_signals": reward_signals}
122138
self._state.raw_state = self._snapshot_state()
123139

124140
return observation
@@ -214,5 +230,23 @@ def _snapshot_state(self) -> Dict[str, Any]:
214230
"game_info": getattr(state, "game_info", {}),
215231
"step_info": getattr(state, "step_info", {}),
216232
}
233+
if self._last_reward_signals:
234+
snapshot["reward_signals"] = dict(self._last_reward_signals)
217235
return snapshot
218236

237+
def _compute_reward_signals(
238+
self, *, action: TextArenaAction, observation: TextArenaObservation
239+
) -> Dict[str, float]:
240+
if not self._reward_providers:
241+
return {}
242+
243+
aggregated: Dict[str, float] = {}
244+
for provider in self._reward_providers:
245+
try:
246+
result = provider.compute(action=action, observation=observation)
247+
except Exception: # pragma: no cover - defensive
248+
continue
249+
for key, value in result.items():
250+
aggregated[key] = float(value)
251+
return aggregated
252+
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
TEXTARENA_ENV_ID="Wordle-v0" TEXTARENA_NUM_PLAYERS=2
1+
export TEXTARENA_ENV_ID="Wordle-v0"
2+
export TEXTARENA_NUM_PLAYERS=1
23

34
# Run the server
4-
exec uvicorn envs.textarena_env.server.app:app --host 0.0.0.0 --port 8000
5+
exec uvicorn envs.textarena_env.server.app:app --host 0.0.0.0 --port 8001
56

67

0 commit comments

Comments
 (0)