Commit 25577bb

minor fix

1 parent d01af13

2 files changed: +69 -31 lines changed

rllm/engine/agent_execution_engine.py

Lines changed: 36 additions & 7 deletions
@@ -437,6 +437,10 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text",
         self.executor = ThreadPoolExecutor(max_workers=self.max_env_workers)
         semaphore = asyncio.Semaphore(self.n_parallel_agents)
 
+        # Initialize skipped indices and valid indices lists (will be populated as trajectories complete)
+        self._last_skipped_indices = []
+        self._last_valid_indices = None  # Will be computed after all trajectories complete
+
         if self.engine_name == "verl":
             self.rollout_engine.wake_up()
 
@@ -456,30 +460,55 @@ async def launch_one_trajectory_task(env_idx: int):
 
                 traceback.print_exc()
                 raise e
-            return result
+            # Return tuple (env_idx, result) so we can track which env returned None
+            return (env_idx, result)
 
         # Create all N conceptual tasks. Their execution will be throttled by the semaphore
         # and the availability of agent/env indices.
         tasks_to_run = [launch_one_trajectory_task(i) for i in range(len(self.envs))]
 
+        # Track results by index to maintain order and identify skipped trajectories
+        results_by_idx = {}
         tasks_completed = 0
        skipped_count = 0
+
        for coro in asyncio.as_completed(tasks_to_run):
            try:
-                result = await coro
+                env_idx, result = await coro
                tasks_completed += 1
 
-                # Skip None results (trajectories that were skipped due to overlong prompts)
+                # Store result with its env_idx (None if skipped)
                if result is None:
                    skipped_count += 1
+                    results_by_idx[env_idx] = None  # Store None to mark as skipped
                    colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed ({skipped_count} skipped due to overlong prompts)", "cyan")
-                    continue
-
-                colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan")
-                yield result
+                else:
+                    results_by_idx[env_idx] = result
+                    colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan")
            except Exception as e:
                raise e
 
+        # Verify all tasks completed and are stored
+        if len(results_by_idx) != len(self.envs):
+            missing = sorted(set(range(len(self.envs))) - set(results_by_idx.keys()))
+            raise RuntimeError(f"Not all trajectories were stored! Missing indices: {missing}. Expected {len(self.envs)} but got {len(results_by_idx)}")
+
+        # Yield all trajectories in order (0 to len(self.envs)-1)
+        # None values indicate skipped trajectories
+        skipped_indices = []
+        for idx in range(len(self.envs)):
+            # All indices should be in results_by_idx after the check above
+            result = results_by_idx[idx]
+            if result is None:
+                skipped_indices.append(idx)
+            yield result  # Yield result (None for skipped, trajectory dict otherwise)
+
+        # Store skipped indices and valid indices as instance variables for trainer to access
+        self._last_skipped_indices = skipped_indices
+        # Compute valid indices (complement of skipped indices) for easier batch filtering
+        total_count = len(self.envs)
+        self._last_valid_indices = [i for i in range(total_count) if i not in skipped_indices]
+
        if self.engine_name == "verl":
            self.rollout_engine.sleep()
 
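Taken together, these changes give trajectory_generator a stable contract for the trainer: it yields exactly len(self.envs) items in env-index order, yields None in the slot of every skipped trajectory, and records _last_skipped_indices / _last_valid_indices once exhausted. Below is a minimal, self-contained sketch of the same collect-by-index-then-yield-in-order pattern; work and ordered_results are illustrative stand-ins, not rLLM APIs.

import asyncio
import random

async def work(idx: int):
    # Stand-in for one rollout; returns (idx, result), where result is None for a "skipped" item.
    await asyncio.sleep(random.random() * 0.01)
    return idx, (None if idx % 3 == 0 else {"idx": idx, "reward": 1.0})

async def ordered_results(n: int):
    # Run all tasks concurrently, storing each result under its index as it completes.
    tasks = [work(i) for i in range(n)]
    results_by_idx = {}
    for coro in asyncio.as_completed(tasks):
        idx, result = await coro
        results_by_idx[idx] = result  # completion order is arbitrary
    # Verify completeness, then yield in index order; None stays in place for skipped items.
    if len(results_by_idx) != n:
        missing = sorted(set(range(n)) - set(results_by_idx))
        raise RuntimeError(f"missing indices: {missing}")
    for i in range(n):
        yield results_by_idx[i]

async def main():
    items = [item async for item in ordered_results(7)]
    skipped = [i for i, item in enumerate(items) if item is None]
    print(f"{len(items)} items, skipped indices: {skipped}")  # 7 items, skipped indices: [0, 3, 6]

asyncio.run(main())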

rllm/trainer/verl/agent_ppo_trainer.py

Lines changed: 33 additions & 24 deletions
@@ -153,6 +153,7 @@ def fit_agent(self):
 
         for epoch in range(self.config.trainer.total_epochs):
             pprint(f"epoch {epoch}, step {self.global_steps} started")
+
             for batch_dict in self.train_dataloader:
                 batch: DataProto = DataProto.from_single_dict(batch_dict)
                 batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
@@ -182,15 +183,25 @@ def fit_agent(self):
                 final_gen_batch_output, generate_metrics = self.generate_agent_trajectory(timing_raw=timing_raw, meta_info=batch.meta_info)
 
                 # If some trajectories were skipped (overlong prompts), filter the batch to match
-                if "skipped_indices" in final_gen_batch_output.meta_info:
-                    skipped_indices = final_gen_batch_output.meta_info.pop("skipped_indices")
-                    # Create mask for valid (non-skipped) indices
-                    valid_mask = np.ones(len(batch.batch), dtype=bool)
-                    valid_mask[skipped_indices] = False
-                    # Filter batch to only include valid samples
-                    valid_indices = np.where(valid_mask)[0]
+                # Get valid indices directly from AgentExecutionEngine (handled internally)
+                valid_indices = getattr(self.agent_execution_engine, "_last_valid_indices", None)
+                skipped_indices = getattr(self.agent_execution_engine, "_last_skipped_indices", [])
+
+                # Ensure batch size matches the number of trajectories collected
+                num_trajectories = len(final_gen_batch_output.batch)
+                if valid_indices is not None and len(valid_indices) != num_trajectories:
+                    if len(valid_indices) > num_trajectories:
+                        valid_indices = valid_indices[:num_trajectories]
+                    else:
+                        raise RuntimeError(f"Fewer valid indices ({len(valid_indices)}) than trajectories ({num_trajectories}).")
+
+                if valid_indices is not None and len(valid_indices) < len(batch.batch):
+                    # Filter batch to only include valid samples (matching the number of trajectories collected)
                     batch = batch.select_idxs(valid_indices)
-                    print(f"Filtered batch from {len(valid_mask)} to {len(valid_indices)} samples after skipping {len(skipped_indices)} overlong prompts")
+
+                # Final sanity check: batch sizes must match before union
+                if len(batch.batch) != len(final_gen_batch_output.batch):
+                    raise RuntimeError(f"Batch size mismatch before union: batch has {len(batch.batch)} samples, final_gen_batch_output has {len(final_gen_batch_output.batch)} samples. valid_indices: {len(valid_indices) if valid_indices else 'None'}, skipped_indices: {len(skipped_indices)}")
 
                 batch = batch.union(final_gen_batch_output)
                 metrics.update(generate_metrics)
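The filtering above reduces to an invariant: read the valid/skipped indices from the engine, sub-select the original batch so it lines up with the trajectories actually collected, and fail loudly if the sizes still disagree before the union. A hedged sketch of that invariant on plain lists (filter_batch is a hypothetical helper; the real code works on DataProto via select_idxs and union):

def filter_batch(batch, trajectories, valid_indices, skipped_indices):
    # Align `batch` (one row per env) with `trajectories` (skipped envs already dropped).
    # `batch` is a plain list here; the real code calls DataProto.select_idxs instead of list indexing.
    num_trajectories = len(trajectories)
    if valid_indices is not None and len(valid_indices) != num_trajectories:
        if len(valid_indices) > num_trajectories:
            valid_indices = valid_indices[:num_trajectories]  # tolerate extra indices
        else:
            raise RuntimeError(f"fewer valid indices ({len(valid_indices)}) than trajectories ({num_trajectories})")
    if valid_indices is not None and len(valid_indices) < len(batch):
        batch = [batch[i] for i in valid_indices]  # ~ batch.select_idxs(valid_indices)
    if len(batch) != num_trajectories:  # sizes must match before union
        raise RuntimeError(f"batch size mismatch: {len(batch)} vs {num_trajectories} (skipped: {skipped_indices})")
    return batch

# Example: envs 1 and 3 were skipped, so prompt rows 1 and 3 are dropped from the batch.
prompts = ["p0", "p1", "p2", "p3", "p4"]
trajs = ["t0", "t2", "t4"]
print(filter_batch(prompts, trajs, valid_indices=[0, 2, 4], skipped_indices=[1, 3]))  # ['p0', 'p2', 'p4']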
@@ -564,7 +575,7 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
         if self.async_rollout_mode:
             gen_seq_generator = self.generate_agent_trajectories_async(timing_raw=timing_raw, meta_info=meta_info, mode="Token")
             for trajectory in gen_seq_generator:
-                # Skip None trajectories (overlong prompts)
+                # Skip None trajectories (overlong prompts) - these are handled by AgentExecutionEngine
                 if trajectory is not None:
                     trajectories.append(trajectory)
                 else:
@@ -574,25 +585,18 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
         if not trajectories:
             raise RuntimeError("All trajectories were skipped (likely all prompts exceed max_prompt_length). Please check your dataset and increase max_prompt_length or enable filtering.")
 
-        # Sort trajectories by their idx, to ensure they are in order.
-        trajectories.sort(key=lambda x: x["idx"])
-
-        # Determine which indices were skipped by checking missing idx values
-        # Expected indices are 0 to (batch_size * rollout.n - 1)
-        expected_count = len(self.agent_execution_engine.envs)
-        actual_indices = set(t["idx"] for t in trajectories)
-        expected_indices = set(range(expected_count))
-        skipped_indices = sorted(expected_indices - actual_indices)
-
+        # Get skipped indices from AgentExecutionEngine (handled internally)
+        skipped_indices = getattr(self.agent_execution_engine, "_last_skipped_indices", [])
         if skipped_indices:
             print(f"Skipped {len(skipped_indices)} trajectories due to overlong prompts at env indices: {skipped_indices}")
 
+        # Sort trajectories by their idx, to ensure they are in order.
+        trajectories.sort(key=lambda x: x["idx"])
+
         with marked_timer("transform_trajectory", timing_raw):
             # Transform the raw trajectories into DataProto format.
             final_gen_batch_output, metrics = self._transform_agent_trajectories(trajectories)
-            # Store skipped indices in meta_info for potential filtering of original batch
-            if skipped_indices:
-                final_gen_batch_output.meta_info["skipped_indices"] = skipped_indices
+
         return final_gen_batch_output, metrics
 
     def generate_agent_steps(self, timing_raw=None, meta_info=None, uids=None):
@@ -865,18 +869,23 @@ def generate_agent_trajectories_async(self, timing_raw=None, meta_info=None, mod
             timing_raw = {}
         queue = Queue()
 
+        # Create a unique sentinel object to signal completion
+        # (Cannot use None since None is used for skipped trajectories)
+        _SENTINEL = object()
+
         def runner():
             async def consume():
                 async for item in self.agent_execution_engine.trajectory_generator(timing_raw=timing_raw, mode=mode, meta_info=meta_info):
                     queue.put(item)
-                queue.put(None)  # sentinel to signal done
+                # Use a special sentinel object instead of None (since None is used for skipped trajectories)
+                queue.put(_SENTINEL)  # sentinel to signal done
 
             asyncio.run(consume())
 
         Thread(target=runner, daemon=True).start()
         while True:
             item = queue.get()
-            if item is None:
+            if item is _SENTINEL:
                 break
             yield item
 
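This last hunk is the usual queue-plus-sentinel bridge from an async generator to a synchronous one; because None is now a legitimate yielded value (a skipped trajectory), the end-of-stream marker must be a private object() that can never compare equal to real data. A self-contained sketch of the pattern with illustrative names (async_source, sync_iter), not the rLLM functions:

import asyncio
from queue import Queue
from threading import Thread

async def async_source():
    # Stand-in for trajectory_generator: None marks a skipped trajectory.
    for i in range(5):
        await asyncio.sleep(0)
        yield None if i == 2 else {"idx": i}

def sync_iter():
    queue = Queue()
    _SENTINEL = object()  # unique object, so it can never collide with a real item (even None)

    def runner():
        async def consume():
            async for item in async_source():
                queue.put(item)
            queue.put(_SENTINEL)  # signal completion without overloading None
        asyncio.run(consume())

    Thread(target=runner, daemon=True).start()
    while True:
        item = queue.get()
        if item is _SENTINEL:
            break
        yield item

print(list(sync_iter()))  # [{'idx': 0}, {'idx': 1}, None, {'idx': 3}, {'idx': 4}]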
