Commit 0360a25

Merge pull request #272 from thwu1/nightly
Fix retokenization
2 parents cd97c06 + 3306c3c

3 files changed: +78 -6 lines

rllm/engine/agent_execution_engine.py

Lines changed: 76 additions & 6 deletions
@@ -144,12 +144,12 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str:
 
         if self.engine_name == "openai":
             output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, enforce_max_prompt_length=False, **sampling_params)
-            return output.text
+            return output
         elif self.engine_name == "verl":
             meta_data = sampling_params.pop("meta_info", {})
             validate = meta_data.get("validate", False)
             output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, enforce_max_prompt_length=False, **sampling_params)
-            return output.text
+            return output
         else:
             raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
 
@@ -232,14 +232,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
             kwargs["max_tokens"] = max_tokens
 
             start_time = time.time()
-            response = await self.get_model_response(prompt_messages, application_id, **kwargs)
+            model_output = await self.get_model_response(prompt_messages, application_id, **kwargs)
+            response = model_output.text
             delta_time = time.time() - start_time
             llm_time += delta_time
             total_time += delta_time
             # Update steps
             prompt_response_pair = {
                 "prompt": self.chat_parser.parse(prompt_messages, add_generation_prompt=True, is_first_msg=True),
                 "response": response,
+                "prompt_ids": model_output.prompt_ids,
+                "completion_ids": model_output.completion_ids,
             }
             episode_steps.append(prompt_response_pair)
 
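The two hunks above stop returning `output.text` and instead pass the engine's full response object through, so the exact token ids the rollout engine consumed and produced are recorded alongside the decoded text. A minimal stand-in for that response object, assuming only the three fields the diff actually reads (`text`, `prompt_ids`, `completion_ids`); the class name and defaults are illustrative, not the repository's API:

from dataclasses import dataclass, field

@dataclass
class ModelOutputStub:
    """Hypothetical container mirroring what run_agent_trajectory_async reads from model_output."""
    text: str                                                # decoded completion, used for parsing/logging
    prompt_ids: list[int] = field(default_factory=list)      # token ids the engine actually consumed
    completion_ids: list[int] = field(default_factory=list)  # token ids the engine actually generated
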
@@ -379,10 +382,11 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
         if mode == "Text":
             return trajectory
         elif mode == "Token":
+            prompt_tokens, response_tokens, response_masks, is_valid_trajectory = self.assemble_steps(episode_steps)
             token_result = {
-                "prompt_tokens": torch.tensor(prompt_tokens, dtype=torch.long),
-                "response_tokens": torch.tensor(response_tokens, dtype=torch.long),
-                "response_masks": torch.tensor(response_masks, dtype=torch.long),
+                "prompt_tokens": prompt_tokens,
+                "response_tokens": response_tokens,
+                "response_masks": response_masks,
                 "trajectory_reward": trajectory.reward,
                 "idx": env.idx,
                 "chat_completions": agent.chat_completions,
@@ -397,6 +401,7 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
                     "llm_time": llm_time,
                     # Total time spent in the trajectory
                     "total_time": total_time,
+                    "token_mismatch": 0.0 if is_valid_trajectory else 1.0,
                 },
             }
             return token_result
@@ -410,6 +415,71 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
                 "mc_returns": [step.mc_return for step in trajectory.steps][: len(episode_steps)],
             }
             return steps_result
+        else:
+            raise ValueError(f"Mode {mode} not supported")
+
+    def assemble_steps(self, steps: list[dict]):
+        """
+        Transform step-by-step results into trajectory format for training.
+        The assembly is aggressive: if the steps are not cumulative, response_masks is set to all 0s.
+
+        Each step in ``steps`` is a dict of the form:
+            {"prompt": str, "response": str, "prompt_ids": list, "completion_ids": list}
+
+        For training, we need to assemble the full conversation sequence where:
+        - prompt_tokens: initial prompt (first step's prompt_ids)
+        - response_tokens: all subsequent conversation (completion_ids + each next step's new prompt_ids)
+        - response_masks: mask indicating which tokens contribute to loss (only completion_ids)
+        """
+
+        # Start with the initial prompt from the first step
+        initial_prompt_ids = steps[0]["prompt_ids"]
+        accumulated_sequence = initial_prompt_ids.copy()
+        response_tokens = []
+        response_masks = []
+        is_valid_trajectory = True
+
+        for i, step in enumerate(steps):
+            current_prompt_ids = step["prompt_ids"]
+            current_completion_ids = step["completion_ids"]
+
+            if i == 0:
+                # First step: just add the completion
+                response_tokens.extend(current_completion_ids)
+                response_masks.extend([1] * len(current_completion_ids))  # completion contributes to loss
+                accumulated_sequence.extend(current_completion_ids)
+            else:
+                if current_prompt_ids[: len(accumulated_sequence)] != accumulated_sequence:
+                    # Find the first differing position
+                    prefix = current_prompt_ids[: len(accumulated_sequence)]
+                    diff_pos = None
+                    for pos, (expected, actual) in enumerate(zip(accumulated_sequence, prefix, strict=False)):
+                        if expected != actual:
+                            diff_pos = pos
+                            break
+
+                    if diff_pos is not None:
+                        logger.warning(f"While assembling steps, the trajectory is not cumulative at position {diff_pos}. Expected: {accumulated_sequence[diff_pos : diff_pos + 5]}, Got: {prefix[diff_pos : diff_pos + 5]}. Setting response_masks to all 0s. This is likely due to retokenization.")
+                    else:
+                        logger.warning(f"While assembling steps, detected a length mismatch. Expected length: {len(accumulated_sequence)}, Got length: {len(prefix)}. Setting response_masks to all 0s.")
+
+                    is_valid_trajectory = False
+                    break
+
+                response_tokens.extend(current_prompt_ids[len(accumulated_sequence) :] + current_completion_ids)
+                response_masks.extend([0] * (len(current_prompt_ids) - len(accumulated_sequence)) + [1] * len(current_completion_ids))  # only the completion contributes to loss
+                accumulated_sequence = current_prompt_ids + current_completion_ids
+
+        assert len(response_masks) == len(response_tokens)
+
+        prompt_tokens = torch.tensor(initial_prompt_ids, dtype=torch.long)
+        response_tokens = torch.tensor(response_tokens, dtype=torch.long)
+        response_masks = torch.tensor(response_masks, dtype=torch.long)
+
+        if self.config.rllm.filter_token_mismatch:
+            response_masks = response_masks * int(is_valid_trajectory)
+
+        return prompt_tokens, response_tokens, response_masks, is_valid_trajectory
 
     async def run_agent_trajectory_with_retry(self, idx, application_id, seed=0, mode="Text", **kwargs):
         for _ in range(self.retry_limit):
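The retokenization fix hinges on each step's prompt_ids extending the previously accumulated token sequence, so a compact sketch can make the assembly behaviour concrete. This is a toy re-implementation with fabricated token ids, not the repository's code; it assumes only the step-dict shape shown in the diff above:

import torch


def assemble_steps_sketch(steps, filter_token_mismatch=True):
    """Toy version of the assembly above: stitch per-step token ids into one trajectory."""
    initial_prompt_ids = steps[0]["prompt_ids"]
    accumulated = list(initial_prompt_ids)
    response_tokens, response_masks = [], []
    is_valid = True

    for i, step in enumerate(steps):
        prompt_ids, completion_ids = step["prompt_ids"], step["completion_ids"]
        if i == 0:
            # First step: only the completion is appended, and it is trained on.
            response_tokens += completion_ids
            response_masks += [1] * len(completion_ids)
            accumulated += completion_ids
            continue
        if prompt_ids[: len(accumulated)] != accumulated:
            # Retokenization changed earlier tokens: the trajectory is no longer cumulative.
            is_valid = False
            break
        new_prompt_part = prompt_ids[len(accumulated):]
        response_tokens += new_prompt_part + completion_ids
        response_masks += [0] * len(new_prompt_part) + [1] * len(completion_ids)
        accumulated = prompt_ids + completion_ids

    masks = torch.tensor(response_masks, dtype=torch.long)
    if filter_token_mismatch:
        masks = masks * int(is_valid)  # zero the whole trajectory's loss mask on mismatch
    return torch.tensor(initial_prompt_ids, dtype=torch.long), torch.tensor(response_tokens, dtype=torch.long), masks, is_valid


# Two fabricated steps whose second prompt extends the first prompt + completion exactly.
steps = [
    {"prompt_ids": [1, 2, 3], "completion_ids": [4, 5]},
    {"prompt_ids": [1, 2, 3, 4, 5, 6], "completion_ids": [7]},
]
prompt, response, masks, ok = assemble_steps_sketch(steps)
print(response.tolist(), masks.tolist(), ok)  # [4, 5, 6, 7] [1, 1, 0, 1] True

In the actual method the mismatch case additionally emits a warning and is reported through the token_mismatch metric shown in the hunks above.
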

rllm/trainer/config/agent_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ rllm:
   disable_thinking: False
   accumulate_reasoning: False
   mask_truncated_samples: False
+  filter_token_mismatch: True
   stepwise_advantage:
     enable: False
     mode: broadcast # [broadcast, per_step]
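The new filter_token_mismatch flag (defaulting to True above) gates the mask-zeroing inside assemble_steps: a trajectory whose retokenized prompt diverges from the accumulated sequence stays in the batch but contributes no loss. A small, self-contained illustration of that gating, assuming an OmegaConf-style config object as is typical for Hydra-based trainers (not a verbatim excerpt from the trainer):

import torch
from omegaconf import OmegaConf

# Fabricated config and masks, just to show the effect of the flag.
cfg = OmegaConf.create({"rllm": {"filter_token_mismatch": True}})
response_masks = torch.tensor([1, 1, 0, 1], dtype=torch.long)
is_valid_trajectory = False  # e.g. a retokenization mismatch was detected

if cfg.rllm.filter_token_mismatch:
    response_masks = response_masks * int(is_valid_trajectory)

print(response_masks.tolist())  # [0, 0, 0, 0] -> this trajectory is excluded from the loss
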

rllm/trainer/verl/agent_ppo_trainer_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ def create_replay_queue(generator, q, batch_iter_val, timing_raw_val):
             # Get the generator function which will yield results as they complete
             if self.config.rllm.agent.step_advantage_broadcast:
                 raise Exception("Stepwise advantage broadcasting not supported on pipelined trainer yet")
+
             gen_seq_generator = self.generate_agent_trajectories_async(timing_raw=timing_raw, meta_info=batch.meta_info)
             thread = threading.Thread(target=create_replay_queue, args=(gen_seq_generator, replay_queue, batch_iter, timing_raw))
             thread.start()