Skip to content

Commit 40c886c

Browse files
1stprinciplelistar2000xzrderek
authored
Integrate Eval Protocol as RL environment (#276)
* fix issue #259 * init * print(f"Train dataset size: {len(train_dataset)}") * super().__init__(rollout_engine, **kwargs) * qwen2p5-vl-32b-instruct * comment * print * [trajectory] * 100 tasks * 10 * n_parallel_tasks = 10 * update to qwen3 and new server * thinking * add train_frozenlake_flow * model_id = "accounts/pyroworks/deployedModels/qwen3-8b-g0m657sn" * print(episode.trajectories) * from datasets import load_dataset * fixed concurrency * tracing_signal * signal * trainer.val_before_train=False \ * _build_rollout_processor_config * print("model in frozen_lake_flow", model) * semaphore=self._rollout_processor_semaphore, * print(error_message) * update self.model to deployedModel * remove some codes * formatting --------- Co-authored-by: listar2000 <lisd.star2015@outlook.com> Co-authored-by: Derek Xu <xzrderek@gmail.com>
1 parent b153d3e commit 40c886c

File tree

6 files changed

+499
-2
lines changed

6 files changed

+499
-2
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""
2+
This workflow bridges eval-protocol's MCPGymRolloutProcessor with rllm-fw's Workflow pattern
3+
for the FrozenLake environment.
4+
"""
5+
6+
import asyncio
7+
from pathlib import Path
8+
9+
import eval_protocol
10+
from eval_protocol.benchmarks.test_frozen_lake import test_frozen_lake_evaluation
11+
from eval_protocol.models import EvaluationRow, InputMetadata, Message
12+
from eval_protocol.pytest.default_mcp_gym_rollout_processor import (
13+
MCPGymRolloutProcessor,
14+
)
15+
from eval_protocol.pytest.types import RolloutProcessorConfig
16+
17+
from rllm.agents.agent import Episode, Step, Trajectory
18+
from rllm.engine.rollout.openai_engine import OpenAIEngine
19+
from rllm.workflows.workflow import Workflow
20+
21+
22+
class FrozenLakeWorkflow(Workflow):
23+
"""
24+
Workflow that executes frozen lake tasks using MCPGymRolloutProcessor.
25+
26+
Task format expected:
27+
{
28+
"id": "frozen_lake_task_0",
29+
"system_prompt": "...",
30+
"environment_context": {...},
31+
"user_prompt_template": "{observation}"
32+
}
33+
"""
34+
35+
# Class variables (shared across all workflow instances)
36+
_shared_server_started = False
37+
_server_lock = asyncio.Lock()
38+
_shared_rollout_processor = MCPGymRolloutProcessor()
39+
40+
def __init__(self, rollout_engine: OpenAIEngine, lite_llm_prefix: str = "fireworks_ai/", max_steps: int = 30, temperature: float = 1.0, max_tokens: int = 4096, **kwargs):
41+
super().__init__(rollout_engine, **kwargs)
42+
43+
self._rollout_processor_server_started = False
44+
self._rollout_processor_semaphore = asyncio.Semaphore(1)
45+
self._lite_llm_prefix = lite_llm_prefix
46+
self._temperature = temperature
47+
self._max_tokens = max_tokens
48+
self._max_steps = max_steps
49+
50+
eval_protocol_path = Path(eval_protocol.__file__).parent
51+
self._server_script_path = eval_protocol_path / "mcp_servers" / "frozen_lake" / "server.py"
52+
53+
# Use shared rollout processor across all instances
54+
self.rollout_processor = FrozenLakeWorkflow._shared_rollout_processor
55+
56+
def _build_rollout_processor_config(self):
57+
model = self._lite_llm_prefix + self.rollout_engine.model
58+
print("model in frozen_lake_flow", model)
59+
return RolloutProcessorConfig(
60+
completion_params={
61+
"model": model,
62+
"temperature": self._temperature,
63+
"max_tokens": self._max_tokens,
64+
},
65+
mcp_config_path="",
66+
server_script_path=str(self._server_script_path),
67+
steps=self._max_steps,
68+
semaphore=self._rollout_processor_semaphore,
69+
kwargs={"start_server": self._rollout_processor_server_started},
70+
)
71+
72+
async def run(self, task: dict, uid: str, **kwargs) -> Episode:
73+
"""
74+
Execute the frozen lake workflow.
75+
76+
Args:
77+
task: Dict containing frozen lake task data
78+
uid: Unique identifier for this episode
79+
**kwargs: Additional arguments
80+
81+
Returns:
82+
Episode with trajectory and computed rewards
83+
"""
84+
# Thread-safe server startup (double-checked locking pattern)
85+
if not FrozenLakeWorkflow._shared_server_started:
86+
# Only acquire lock if server not started yet
87+
async with FrozenLakeWorkflow._server_lock:
88+
# Check again inside lock (another workflow might have started it)
89+
if not FrozenLakeWorkflow._shared_server_started:
90+
# First workflow to reach here starts the server
91+
self._rollout_processor_server_started = True
92+
FrozenLakeWorkflow._shared_server_started = True
93+
else:
94+
self._rollout_processor_server_started = False
95+
else:
96+
self._rollout_processor_server_started = False
97+
98+
self.reset(task=task, uid=uid)
99+
100+
try:
101+
eval_row = self._task_to_evaluation_row(task)
102+
103+
tasks = self.rollout_processor([eval_row], self._build_rollout_processor_config())
104+
105+
if not tasks:
106+
raise ValueError("MCPGymRolloutProcessor returned no tasks")
107+
108+
result_row: EvaluationRow = await tasks[0]
109+
110+
episode = await self._evaluate_and_create_episode(result_row, task, uid)
111+
112+
return episode
113+
114+
except Exception as e:
115+
# Gracefully handle failures - return a failed Episode instead of crashing
116+
print(f"⚠️ Task {uid} failed: {e}")
117+
118+
failed_episode = Episode(
119+
id=uid,
120+
task=task,
121+
is_correct=False,
122+
trajectories=[],
123+
metrics={"frozen_lake_reward": 0.0, "error": str(e)},
124+
)
125+
return failed_episode
126+
127+
def _task_to_evaluation_row(self, task: dict) -> EvaluationRow:
128+
"""Convert rllm task dict to eval protocol EvaluationRow."""
129+
return EvaluationRow(
130+
messages=[Message(role="system", content=task["system_prompt"])],
131+
input_metadata=InputMetadata(
132+
row_id=task["id"],
133+
dataset_info={
134+
"environment_context": task["environment_context"],
135+
"user_prompt_template": task["user_prompt_template"],
136+
},
137+
),
138+
)
139+
140+
async def _evaluate_and_create_episode(
141+
self,
142+
row: EvaluationRow,
143+
task: dict,
144+
uid: str,
145+
) -> Episode:
146+
"""
147+
Evaluate the rollout and convert to rllm Episode.
148+
"""
149+
# Call the evaluation function
150+
evaluated_row: EvaluationRow = await test_frozen_lake_evaluation(row)
151+
152+
# Extract reward and metrics from evaluation_result
153+
if evaluated_row.evaluation_result is None:
154+
raise ValueError("Evaluation function did not return a result")
155+
156+
reward = evaluated_row.evaluation_result.score
157+
reward_info = evaluated_row.evaluation_result.metrics or {}
158+
159+
def msg_to_dict(msg: Message) -> dict:
160+
"""Convert eval_protocol Message to chat completion dict."""
161+
d = {"role": msg.role, "content": msg.content}
162+
if msg.tool_calls:
163+
d["tool_calls"] = [
164+
{
165+
"id": tc.id,
166+
"type": tc.type,
167+
"function": {
168+
"name": tc.function.name,
169+
"arguments": tc.function.arguments,
170+
},
171+
}
172+
for tc in msg.tool_calls
173+
]
174+
if msg.tool_call_id:
175+
d["tool_call_id"] = msg.tool_call_id
176+
if msg.name:
177+
d["name"] = msg.name
178+
return d
179+
180+
trajectory = Trajectory()
181+
all_messages = []
182+
183+
for msg in row.messages:
184+
msg_dict = msg_to_dict(msg)
185+
all_messages.append(msg_dict)
186+
187+
# Create Step with only observation and chat_completions for user or tool message
188+
if msg.role in ["user", "tool"]:
189+
new_step = Step(observation=str(msg.content or ""), chat_completions=all_messages.copy())
190+
trajectory.steps.append(new_step)
191+
192+
# Create new Step with action/response for assistant message
193+
elif msg.role == "assistant":
194+
# Extract action: tool calls if present, otherwise message content
195+
action_data = msg_dict.get("tool_calls") if msg.tool_calls else str(msg.content or "")
196+
197+
new_step = Step(
198+
model_response=str(msg.content) if msg.content else "",
199+
action=action_data,
200+
chat_completions=all_messages.copy(),
201+
)
202+
trajectory.steps.append(new_step)
203+
204+
# Assign final reward to the last step (sparse reward)
205+
if trajectory.steps:
206+
trajectory.steps[-1].reward = reward
207+
trajectory.steps[-1].info = reward_info
208+
209+
trajectory.reward = reward
210+
trajectory.task = task
211+
212+
# Create episode
213+
episode = Episode(
214+
id=uid,
215+
task=task,
216+
is_correct=(reward == 1.0),
217+
trajectories=[trajectory],
218+
metrics={"frozen_lake_reward": reward, **reward_info},
219+
)
220+
221+
return episode
222+
223+
def cleanup(self):
224+
"""Cleanup MCP server resources."""
225+
if self.rollout_processor:
226+
self.rollout_processor.cleanup()
227+
self.rollout_processor = None
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import random
2+
3+
from datasets import Dataset
4+
5+
from rllm.data.dataset import DatasetRegistry
6+
7+
8+
def prepare_frozen_lake_data(train_size: int, test_size: int):
9+
system_prompt = "You are playing FrozenLake, a grid-based navigation game displayed as a 4x4 text grid. The grid contains: S (Start), F (Frozen safe), H (Hole - deadly), G (Goal). You start at position S and must reach G while avoiding H tiles. In this version, the surface is not slippery so your moves are deterministic. IMPORTANT: When you are at the starting position, you appear as 'S'. When you move to other positions, the hightlighted position will change on the grid. If you step on H, the episode ends with failure. Use the lake_move tool with actions LEFT, DOWN, RIGHT, UP to navigate the grid."
10+
user_prompt_template = "Current game state grid:\n{observation}\n\nYou are navigating the 4x4 grid above. Navigate safely to reach the goal 'G' while avoiding holes 'H'. Choose your next move from: LEFT, DOWN, RIGHT, or UP."
11+
12+
def create_row(idx, seed):
13+
return {"id": f"run_{idx}", "system_prompt": system_prompt, "user_prompt_template": user_prompt_template, "environment_context": {"game": "FrozenLake", "map_name": "4x4", "seed": seed}}
14+
15+
seeds = random.sample(range(1, 1_000_001), train_size + test_size)
16+
all_rows = []
17+
for i in range(train_size + test_size):
18+
all_rows.append(create_row(i, seeds[i]))
19+
train_rows = all_rows[:train_size]
20+
test_rows = all_rows[train_size:]
21+
22+
train_dataset = Dataset.from_list(train_rows)
23+
test_dataset = Dataset.from_list(test_rows)
24+
25+
DatasetRegistry.register_dataset("frozen_lake_eval_protocol", train_dataset, "train")
26+
DatasetRegistry.register_dataset("frozen_lake_eval_protocol", test_dataset, "test")
27+
28+
print(f"Train dataset size: {len(train_dataset)}")
29+
print(f"Test dataset size: {len(test_dataset)}")
30+
31+
32+
if __name__ == "__main__":
33+
prepare_frozen_lake_data(train_size=100, test_size=100)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""
2+
Run Frozen Lake Workflow with rllm-fw
3+
4+
This script demonstrates how to execute frozen lake tasks using rllm-fw's
5+
AgentWorkflowEngine with eval-protocol's MCPGymRolloutProcessor.
6+
"""
7+
8+
import asyncio
9+
import json
10+
import os
11+
from pathlib import Path
12+
13+
from frozen_lake_flow import FrozenLakeWorkflow
14+
15+
from rllm.data.dataset import DatasetRegistry
16+
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
17+
from rllm.engine.rollout.openai_engine import OpenAIEngine
18+
19+
20+
def evaluate_results(episodes):
21+
"""
22+
Evaluate the results and compute accuracy metrics.
23+
24+
Args:
25+
episodes: List of Episode objects
26+
"""
27+
total = len(episodes)
28+
correct = sum(1 for ep in episodes if ep.is_correct)
29+
accuracy = correct / total if total > 0 else 0.0
30+
31+
print("\n" + "=" * 60)
32+
print("EVALUATION RESULTS")
33+
print("=" * 60)
34+
print(f"Total tasks: {total}")
35+
print(f"Correct: {correct}")
36+
print(f"Accuracy: {accuracy:.2%}")
37+
print()
38+
39+
for episode in episodes:
40+
status = "✅" if episode.is_correct else "❌"
41+
reward = episode.metrics.get("frozen_lake_reward", 0.0)
42+
print(f"{status} Task {episode.id}: reward={reward:.3f}")
43+
44+
print("=" * 60)
45+
46+
return accuracy
47+
48+
49+
async def main():
50+
"""Main execution function."""
51+
52+
n_parallel_tasks = 4
53+
max_tasks = 4
54+
model_id = "accounts/pyroworks/deployedModels/qwen3-8b-g0m657sn"
55+
56+
# Create dummy rollout_engine (required by Workflow base class but not used)
57+
rollout_engine = OpenAIEngine(
58+
model=model_id,
59+
base_url="https://api.fireworks.ai/inference/v1",
60+
api_key=os.getenv("FIREWORKS_API_KEY"),
61+
)
62+
63+
engine = AgentWorkflowEngine(
64+
workflow_cls=FrozenLakeWorkflow,
65+
workflow_args={
66+
"lite_llm_prefix": "fireworks_ai/",
67+
"steps": 30,
68+
"temperature": 1.0,
69+
"max_tokens": 16384,
70+
},
71+
rollout_engine=rollout_engine,
72+
n_parallel_tasks=n_parallel_tasks,
73+
retry_limit=1,
74+
)
75+
76+
test_dataset = DatasetRegistry.load_dataset("frozen_lake_eval_protocol", "test")
77+
tasks = []
78+
for i in range(max_tasks):
79+
tasks.append(test_dataset[i])
80+
81+
print("Starting frozen lake workflow execution...")
82+
print(f"Model: {model_id}")
83+
print(f"Parallel tasks: {n_parallel_tasks}")
84+
print()
85+
86+
try:
87+
episodes = await engine.execute_tasks(tasks)
88+
for episode in episodes:
89+
print(episode.trajectories)
90+
accuracy = evaluate_results(episodes)
91+
92+
output_dir = Path("logs")
93+
output_dir.mkdir(exist_ok=True)
94+
output_file = output_dir / "frozen_lake_results.json"
95+
96+
with open(output_file, "w") as f:
97+
json.dump([episode.to_dict() for episode in episodes], f, indent=2)
98+
99+
print(f"\n✅ Results saved to {output_file}")
100+
101+
return accuracy
102+
103+
except Exception as e:
104+
print(f"❌ Error during execution: {e}")
105+
import traceback
106+
107+
traceback.print_exc()
108+
raise
109+
finally:
110+
engine.shutdown()
111+
112+
113+
if __name__ == "__main__":
114+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
115+
116+
accuracy = asyncio.run(main())
117+
118+
print(f"\n🎯 Final Accuracy: {accuracy:.2%}")

0 commit comments

Comments
 (0)