@@ -153,6 +153,7 @@ def fit_agent(self):
153153
154154 for epoch in range (self .config .trainer .total_epochs ):
155155 pprint (f"epoch { epoch } , step { self .global_steps } started" )
156+
156157 for batch_dict in self .train_dataloader :
157158 batch : DataProto = DataProto .from_single_dict (batch_dict )
158159 batch .non_tensor_batch ["uid" ] = np .array ([str (uuid .uuid4 ()) for _ in range (len (batch .batch ))], dtype = object )
@@ -182,15 +183,25 @@ def fit_agent(self):
182183 final_gen_batch_output , generate_metrics = self .generate_agent_trajectory (timing_raw = timing_raw , meta_info = batch .meta_info )
183184
184185 # If some trajectories were skipped (overlong prompts), filter the batch to match
185- if "skipped_indices" in final_gen_batch_output .meta_info :
186- skipped_indices = final_gen_batch_output .meta_info .pop ("skipped_indices" )
187- # Create mask for valid (non-skipped) indices
188- valid_mask = np .ones (len (batch .batch ), dtype = bool )
189- valid_mask [skipped_indices ] = False
190- # Filter batch to only include valid samples
191- valid_indices = np .where (valid_mask )[0 ]
186+ # Get valid indices directly from AgentExecutionEngine (handled internally)
187+ valid_indices = getattr (self .agent_execution_engine , "_last_valid_indices" , None )
188+ skipped_indices = getattr (self .agent_execution_engine , "_last_skipped_indices" , [])
189+
190+ # Ensure batch size matches the number of trajectories collected
191+ num_trajectories = len (final_gen_batch_output .batch )
192+ if valid_indices is not None and len (valid_indices ) != num_trajectories :
193+ if len (valid_indices ) > num_trajectories :
194+ valid_indices = valid_indices [:num_trajectories ]
195+ else :
196+ raise RuntimeError (f"Fewer valid indices ({ len (valid_indices )} ) than trajectories ({ num_trajectories } )." )
197+
198+ if valid_indices is not None and len (valid_indices ) < len (batch .batch ):
199+ # Filter batch to only include valid samples (matching the number of trajectories collected)
192200 batch = batch .select_idxs (valid_indices )
193- print (f"Filtered batch from { len (valid_mask )} to { len (valid_indices )} samples after skipping { len (skipped_indices )} overlong prompts" )
201+
202+ # Final sanity check: batch sizes must match before union
203+ if len (batch .batch ) != len (final_gen_batch_output .batch ):
204+ raise RuntimeError (f"Batch size mismatch before union: batch has { len (batch .batch )} samples, final_gen_batch_output has { len (final_gen_batch_output .batch )} samples. valid_indices: { len (valid_indices ) if valid_indices else 'None' } , skipped_indices: { len (skipped_indices )} " )
194205
195206 batch = batch .union (final_gen_batch_output )
196207 metrics .update (generate_metrics )
@@ -564,7 +575,7 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
564575 if self .async_rollout_mode :
565576 gen_seq_generator = self .generate_agent_trajectories_async (timing_raw = timing_raw , meta_info = meta_info , mode = "Token" )
566577 for trajectory in gen_seq_generator :
567- # Skip None trajectories (overlong prompts)
578+ # Skip None trajectories (overlong prompts) - these are handled by AgentExecutionEngine
568579 if trajectory is not None :
569580 trajectories .append (trajectory )
570581 else :
@@ -574,25 +585,18 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
574585 if not trajectories :
575586 raise RuntimeError ("All trajectories were skipped (likely all prompts exceed max_prompt_length). Please check your dataset and increase max_prompt_length or enable filtering." )
576587
577- # Sort trajectories by their idx, to ensure they are in order.
578- trajectories .sort (key = lambda x : x ["idx" ])
579-
580- # Determine which indices were skipped by checking missing idx values
581- # Expected indices are 0 to (batch_size * rollout.n - 1)
582- expected_count = len (self .agent_execution_engine .envs )
583- actual_indices = set (t ["idx" ] for t in trajectories )
584- expected_indices = set (range (expected_count ))
585- skipped_indices = sorted (expected_indices - actual_indices )
586-
588+ # Get skipped indices from AgentExecutionEngine (handled internally)
589+ skipped_indices = getattr (self .agent_execution_engine , "_last_skipped_indices" , [])
587590 if skipped_indices :
588591 print (f"Skipped { len (skipped_indices )} trajectories due to overlong prompts at env indices: { skipped_indices } " )
589592
593+ # Sort trajectories by their idx, to ensure they are in order.
594+ trajectories .sort (key = lambda x : x ["idx" ])
595+
590596 with marked_timer ("transform_trajectory" , timing_raw ):
591597 # Transform the raw trajectories into DataProto format.
592598 final_gen_batch_output , metrics = self ._transform_agent_trajectories (trajectories )
593- # Store skipped indices in meta_info for potential filtering of original batch
594- if skipped_indices :
595- final_gen_batch_output .meta_info ["skipped_indices" ] = skipped_indices
599+
596600 return final_gen_batch_output , metrics
597601
598602 def generate_agent_steps (self , timing_raw = None , meta_info = None , uids = None ):
@@ -865,18 +869,23 @@ def generate_agent_trajectories_async(self, timing_raw=None, meta_info=None, mod
865869 timing_raw = {}
866870 queue = Queue ()
867871
872+ # Create a unique sentinel object to signal completion
873+ # (Cannot use None since None is used for skipped trajectories)
874+ _SENTINEL = object ()
875+
868876 def runner ():
869877 async def consume ():
870878 async for item in self .agent_execution_engine .trajectory_generator (timing_raw = timing_raw , mode = mode , meta_info = meta_info ):
871879 queue .put (item )
872- queue .put (None ) # sentinel to signal done
880+ # Use a special sentinel object instead of None (since None is used for skipped trajectories)
881+ queue .put (_SENTINEL ) # sentinel to signal done
873882
874883 asyncio .run (consume ())
875884
876885 Thread (target = runner , daemon = True ).start ()
877886 while True :
878887 item = queue .get ()
879- if item is None :
888+ if item is _SENTINEL :
880889 break
881890 yield item
882891
0 commit comments