|
| 1 | +""" |
| 2 | +This workflow bridges eval-protocol's MCPGymRolloutProcessor with rllm-fw's Workflow pattern |
| 3 | +for the FrozenLake environment. |
| 4 | +""" |
| 5 | + |
| 6 | +import asyncio |
| 7 | +from pathlib import Path |
| 8 | + |
| 9 | +import eval_protocol |
| 10 | +from eval_protocol.benchmarks.test_frozen_lake import test_frozen_lake_evaluation |
| 11 | +from eval_protocol.models import EvaluationRow, InputMetadata, Message |
| 12 | +from eval_protocol.pytest.default_mcp_gym_rollout_processor import ( |
| 13 | + MCPGymRolloutProcessor, |
| 14 | +) |
| 15 | +from eval_protocol.pytest.types import RolloutProcessorConfig |
| 16 | + |
| 17 | +from rllm.agents.agent import Episode, Step, Trajectory |
| 18 | +from rllm.engine.rollout.openai_engine import OpenAIEngine |
| 19 | +from rllm.workflows.workflow import Workflow |
| 20 | + |
| 21 | + |
class FrozenLakeWorkflow(Workflow):
    """
    Workflow that executes frozen lake tasks using MCPGymRolloutProcessor.

    Bridges eval-protocol's MCP gym rollout machinery with rllm's Workflow
    pattern: each ``run()`` converts an rllm task dict into an eval-protocol
    ``EvaluationRow``, rolls it out against the FrozenLake MCP server, scores
    it with ``test_frozen_lake_evaluation``, and converts the result back
    into an rllm ``Episode``.

    Task format expected:
        {
            "id": "frozen_lake_task_0",
            "system_prompt": "...",
            "environment_context": {...},
            "user_prompt_template": "{observation}"
        }
    """

    # Class variables (shared across ALL workflow instances): exactly one
    # run() must launch the MCP server; every other rollout reuses it.
    _shared_server_started = False
    # NOTE(review): asyncio.Lock() is instantiated at import time, before an
    # event loop exists. This is fine on Python >= 3.10 (the loop is bound
    # lazily) but deprecated on earlier versions — confirm the minimum
    # supported runtime.
    _server_lock = asyncio.Lock()
    _shared_rollout_processor = MCPGymRolloutProcessor()

    def __init__(self, rollout_engine: OpenAIEngine, lite_llm_prefix: str = "fireworks_ai/", max_steps: int = 30, temperature: float = 1.0, max_tokens: int = 4096, **kwargs):
        """
        Args:
            rollout_engine: rllm engine whose ``.model`` name is forwarded
                (with ``lite_llm_prefix`` prepended) to the rollout processor.
            lite_llm_prefix: LiteLLM provider prefix for the model name.
            max_steps: Maximum environment steps per rollout.
            temperature: Sampling temperature for completions.
            max_tokens: Max completion tokens per turn.
            **kwargs: Forwarded to the base Workflow.
        """
        super().__init__(rollout_engine, **kwargs)

        # Kept for backward compatibility; run() now passes the start/no-start
        # decision explicitly (see _build_rollout_processor_config).
        self._rollout_processor_server_started = False
        self._rollout_processor_semaphore = asyncio.Semaphore(1)
        self._lite_llm_prefix = lite_llm_prefix
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._max_steps = max_steps

        # Locate the FrozenLake MCP server script that ships inside the
        # installed eval_protocol package.
        eval_protocol_path = Path(eval_protocol.__file__).parent
        self._server_script_path = eval_protocol_path / "mcp_servers" / "frozen_lake" / "server.py"

        # Use shared rollout processor across all instances
        self.rollout_processor = FrozenLakeWorkflow._shared_rollout_processor

    def _build_rollout_processor_config(self, start_server=None):
        """
        Build the RolloutProcessorConfig for a single rollout.

        Args:
            start_server: Whether this rollout should launch the MCP server.
                ``None`` (default) falls back to the legacy instance flag for
                backward compatibility.
        """
        if start_server is None:
            start_server = self._rollout_processor_server_started
        model = self._lite_llm_prefix + self.rollout_engine.model
        print("model in frozen_lake_flow", model)
        return RolloutProcessorConfig(
            completion_params={
                "model": model,
                "temperature": self._temperature,
                "max_tokens": self._max_tokens,
            },
            mcp_config_path="",
            server_script_path=str(self._server_script_path),
            steps=self._max_steps,
            semaphore=self._rollout_processor_semaphore,
            kwargs={"start_server": start_server},
        )

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        """
        Execute the frozen lake workflow.

        Args:
            task: Dict containing frozen lake task data
            uid: Unique identifier for this episode
            **kwargs: Additional arguments

        Returns:
            Episode with trajectory and computed rewards. On any failure a
            failed Episode (reward 0.0, error string in metrics) is returned
            instead of raising.
        """
        # Async-safe server startup (double-checked locking): exactly one
        # concurrent run() wins the right to start the MCP server.
        start_server = False
        if not FrozenLakeWorkflow._shared_server_started:
            # Only acquire the lock when the server looks unstarted.
            async with FrozenLakeWorkflow._server_lock:
                # Re-check inside the lock — another coroutine may have won.
                if not FrozenLakeWorkflow._shared_server_started:
                    FrozenLakeWorkflow._shared_server_started = True
                    start_server = True

        # BUG FIX: carry the decision in a local and pass it explicitly to the
        # config builder. Previously it was stashed on
        # ``self._rollout_processor_server_started``, which a concurrent run()
        # on the same instance could reset to False before the config was
        # built, so the winning coroutine would never actually start the
        # server. The attribute is still updated for backward compatibility.
        self._rollout_processor_server_started = start_server

        self.reset(task=task, uid=uid)

        try:
            eval_row = self._task_to_evaluation_row(task)

            tasks = self.rollout_processor(
                [eval_row],
                self._build_rollout_processor_config(start_server=start_server),
            )

            if not tasks:
                raise ValueError("MCPGymRolloutProcessor returned no tasks")

            result_row: EvaluationRow = await tasks[0]

            episode = await self._evaluate_and_create_episode(result_row, task, uid)

            return episode

        except Exception as e:
            # Gracefully handle failures - return a failed Episode instead of crashing
            print(f"⚠️ Task {uid} failed: {e}")

            failed_episode = Episode(
                id=uid,
                task=task,
                is_correct=False,
                trajectories=[],
                metrics={"frozen_lake_reward": 0.0, "error": str(e)},
            )
            return failed_episode

    def _task_to_evaluation_row(self, task: dict) -> EvaluationRow:
        """Convert rllm task dict to eval protocol EvaluationRow."""
        return EvaluationRow(
            messages=[Message(role="system", content=task["system_prompt"])],
            input_metadata=InputMetadata(
                row_id=task["id"],
                dataset_info={
                    "environment_context": task["environment_context"],
                    "user_prompt_template": task["user_prompt_template"],
                },
            ),
        )

    async def _evaluate_and_create_episode(
        self,
        row: EvaluationRow,
        task: dict,
        uid: str,
    ) -> Episode:
        """
        Evaluate the rollout and convert to rllm Episode.

        Scores ``row`` with eval-protocol's frozen-lake evaluator, then maps
        the message transcript onto rllm Steps: user/tool messages become
        observation steps, assistant messages become action steps. The final
        reward is sparse — assigned only to the last step.
        """
        # Call the evaluation function
        evaluated_row: EvaluationRow = await test_frozen_lake_evaluation(row)

        # Extract reward and metrics from evaluation_result
        if evaluated_row.evaluation_result is None:
            raise ValueError("Evaluation function did not return a result")

        reward = evaluated_row.evaluation_result.score
        reward_info = evaluated_row.evaluation_result.metrics or {}

        def msg_to_dict(msg: Message) -> dict:
            """Convert eval_protocol Message to chat completion dict."""
            d = {"role": msg.role, "content": msg.content}
            if msg.tool_calls:
                d["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in msg.tool_calls
                ]
            if msg.tool_call_id:
                d["tool_call_id"] = msg.tool_call_id
            if msg.name:
                d["name"] = msg.name
            return d

        trajectory = Trajectory()
        all_messages = []

        # NOTE(review): iterates ``row.messages`` (the input), presumably the
        # evaluator mutates/returns the same row object — confirm that
        # ``evaluated_row.messages`` would not differ.
        for msg in row.messages:
            msg_dict = msg_to_dict(msg)
            all_messages.append(msg_dict)

            # Create Step with only observation and chat_completions for user or tool message
            if msg.role in ["user", "tool"]:
                new_step = Step(observation=str(msg.content or ""), chat_completions=all_messages.copy())
                trajectory.steps.append(new_step)

            # Create new Step with action/response for assistant message
            elif msg.role == "assistant":
                # Extract action: tool calls if present, otherwise message content
                action_data = msg_dict.get("tool_calls") if msg.tool_calls else str(msg.content or "")

                new_step = Step(
                    model_response=str(msg.content) if msg.content else "",
                    action=action_data,
                    chat_completions=all_messages.copy(),
                )
                trajectory.steps.append(new_step)

        # Assign final reward to the last step (sparse reward)
        if trajectory.steps:
            trajectory.steps[-1].reward = reward
            trajectory.steps[-1].info = reward_info

        trajectory.reward = reward
        trajectory.task = task

        # Create episode; exact float equality is intentional — FrozenLake's
        # evaluator emits 1.0 only on success.
        episode = Episode(
            id=uid,
            task=task,
            is_correct=(reward == 1.0),
            trajectories=[trajectory],
            metrics={"frozen_lake_reward": reward, **reward_info},
        )

        return episode

    def cleanup(self):
        """Cleanup MCP server resources."""
        if self.rollout_processor:
            self.rollout_processor.cleanup()
            self.rollout_processor = None
        # BUG FIX: reset the shared flag so a later workflow can relaunch the
        # MCP server. Previously the flag stayed True after cleanup, so every
        # subsequent run() built its config with start_server=False and the
        # (now torn-down) server was never restarted.
        FrozenLakeWorkflow._shared_server_started = False
0 commit comments