Skip to content

Commit fda3768

Browse files
committed
fix format
1 parent 40c886c commit fda3768

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

rllm/engine/agent_execution_engine.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
203203
messages = agent.chat_completions
204204
prompt_tokens, _ = convert_messages_to_tokens_and_masks(messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=True, contains_generation_msg=True)
205205
prompt_token_len = len(prompt_tokens)
206-
# Note, this should never happen!
206+
207+
# Check if initial prompt already exceeds max length
208+
# This can happen if:
209+
# 1. Dataset filtering didn't catch this sample (e.g., different tokenization)
210+
# 2. Checkpoint contains cached dataset that wasn't filtered (delete checkpoint's data.pt)
207211
if prompt_token_len > self.max_prompt_length:
208-
agent.reset()
209-
raise Exception(f"Trajectory {idx}: initial prompt length {prompt_token_len} already exceeded max_prompt_length {self.max_prompt_length}, retrying")
212+
logger.warning(f"Trajectory {idx}: Initial prompt length {prompt_token_len} exceeds max_prompt_length {self.max_prompt_length}. Skipping this sample entirely (no trajectory will be returned). First 200 chars of prompt: {self.chat_parser.parse(messages[:1], add_generation_prompt=False)[:200]}...")
213+
214+
# Close the environment and return None to skip this trajectory entirely
215+
await loop.run_in_executor(self.executor, env.close)
216+
return None
210217

211218
for step_idx in range(self.max_steps):
212219
# Get action from agent
@@ -410,7 +417,11 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
410417
async def run_agent_trajectory_with_retry(self, idx, application_id, seed=0, mode="Text", **kwargs):
411418
for _ in range(self.retry_limit):
412419
try:
413-
return await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=7200)
420+
result = await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=7200)
421+
# If result is None, it means the trajectory was skipped (e.g., overlong prompt)
422+
if result is None:
423+
return None
424+
return result
414425
except Exception:
415426
traceback.print_exc()
416427
continue
@@ -452,10 +463,18 @@ async def launch_one_trajectory_task(env_idx: int):
452463
tasks_to_run = [launch_one_trajectory_task(i) for i in range(len(self.envs))]
453464

454465
tasks_completed = 0
466+
skipped_count = 0
455467
for coro in asyncio.as_completed(tasks_to_run):
456468
try:
457469
result = await coro
458470
tasks_completed += 1
471+
472+
# Skip None results (trajectories that were skipped due to overlong prompts)
473+
if result is None:
474+
skipped_count += 1
475+
colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed ({skipped_count} skipped due to overlong prompts)", "cyan")
476+
continue
477+
459478
colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan")
460479
yield result
461480
except Exception as e:

rllm/trainer/verl/agent_ppo_trainer.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,18 @@ def fit_agent(self):
180180
batch = self._pad_dataproto_to_world_size(batch=batch)
181181
else:
182182
final_gen_batch_output, generate_metrics = self.generate_agent_trajectory(timing_raw=timing_raw, meta_info=batch.meta_info)
183+
184+
# If some trajectories were skipped (overlong prompts), filter the batch to match
185+
if "skipped_indices" in final_gen_batch_output.meta_info:
186+
skipped_indices = final_gen_batch_output.meta_info.pop("skipped_indices")
187+
# Create mask for valid (non-skipped) indices
188+
valid_mask = np.ones(len(batch.batch), dtype=bool)
189+
valid_mask[skipped_indices] = False
190+
# Filter batch to only include valid samples
191+
valid_indices = np.where(valid_mask)[0]
192+
batch = batch.select_idxs(valid_indices)
193+
print(f"Filtered batch from {len(valid_mask)} to {len(valid_indices)} samples after skipping {len(skipped_indices)} overlong prompts")
194+
183195
batch = batch.union(final_gen_batch_output)
184196
metrics.update(generate_metrics)
185197

@@ -551,16 +563,36 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
551563
trajectories = []
552564
if self.async_rollout_mode:
553565
gen_seq_generator = self.generate_agent_trajectories_async(timing_raw=timing_raw, meta_info=meta_info, mode="Token")
554-
for _, trajectory in enumerate(gen_seq_generator):
555-
trajectories.append(trajectory)
566+
for trajectory in gen_seq_generator:
567+
# Skip None trajectories (overlong prompts)
568+
if trajectory is not None:
569+
trajectories.append(trajectory)
556570
else:
557571
raise ValueError("Only async rollout mode is supported")
572+
573+
# Check if all trajectories were skipped
574+
if not trajectories:
575+
raise RuntimeError("All trajectories were skipped (likely all prompts exceed max_prompt_length). Please check your dataset and increase max_prompt_length or enable filtering.")
576+
558577
# Sort trajectories by their idx, to ensure they are in order.
559578
trajectories.sort(key=lambda x: x["idx"])
560579

580+
# Determine which indices were skipped by checking missing idx values
581+
# Expected indices are 0 to (batch_size * rollout.n - 1)
582+
expected_count = len(self.agent_execution_engine.envs)
583+
actual_indices = set(t["idx"] for t in trajectories)
584+
expected_indices = set(range(expected_count))
585+
skipped_indices = sorted(expected_indices - actual_indices)
586+
587+
if skipped_indices:
588+
print(f"Skipped {len(skipped_indices)} trajectories due to overlong prompts at env indices: {skipped_indices}")
589+
561590
with marked_timer("transform_trajectory", timing_raw):
562591
# Transform the raw trajectories into DataProto format.
563592
final_gen_batch_output, metrics = self._transform_agent_trajectories(trajectories)
593+
# Store skipped indices in meta_info for potential filtering of original batch
594+
if skipped_indices:
595+
final_gen_batch_output.meta_info["skipped_indices"] = skipped_indices
564596
return final_gen_batch_output, metrics
565597

566598
def generate_agent_steps(self, timing_raw=None, meta_info=None, uids=None):

0 commit comments

Comments (0)