@@ -144,12 +144,12 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str:
 
         if self.engine_name == "openai":
             output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, enforce_max_prompt_length=False, **sampling_params)
-            return output.text
+            return output
         elif self.engine_name == "verl":
             meta_data = sampling_params.pop("meta_info", {})
             validate = meta_data.get("validate", False)
             output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, enforce_max_prompt_length=False, **sampling_params)
-            return output.text
+            return output
         else:
             raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
 
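Both branches now return the full engine output object instead of just `output.text`, so the caller can also read the token ids. A minimal sketch of the shape this assumes; the field names `text`, `prompt_ids`, and `completion_ids` are inferred from how the object is used below, and the class itself is hypothetical:

```python
from dataclasses import dataclass


@dataclass
class ModelOutput:
    """Assumed shape of the rollout engine's response (hypothetical name)."""

    text: str  # decoded completion; previously the only value returned
    prompt_ids: list[int]  # token ids of the prompt exactly as the engine tokenized it
    completion_ids: list[int]  # token ids of the sampled completion
```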
@@ -232,14 +232,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
                 kwargs["max_tokens"] = max_tokens
 
             start_time = time.time()
-            response = await self.get_model_response(prompt_messages, application_id, **kwargs)
+            model_output = await self.get_model_response(prompt_messages, application_id, **kwargs)
+            response = model_output.text
             delta_time = time.time() - start_time
             llm_time += delta_time
             total_time += delta_time
             # Update steps
             prompt_response_pair = {
                 "prompt": self.chat_parser.parse(prompt_messages, add_generation_prompt=True, is_first_msg=True),
                 "response": response,
+                "prompt_ids": model_output.prompt_ids,
+                "completion_ids": model_output.completion_ids,
             }
             episode_steps.append(prompt_response_pair)
@@ -379,10 +382,11 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
         if mode == "Text":
             return trajectory
         elif mode == "Token":
+            prompt_tokens, response_tokens, response_masks, is_valid_trajectory = self.assemble_steps(episode_steps)
             token_result = {
-                "prompt_tokens": torch.tensor(prompt_tokens, dtype=torch.long),
-                "response_tokens": torch.tensor(response_tokens, dtype=torch.long),
-                "response_masks": torch.tensor(response_masks, dtype=torch.long),
+                "prompt_tokens": prompt_tokens,
+                "response_tokens": response_tokens,
+                "response_masks": response_masks,
                 "trajectory_reward": trajectory.reward,
                 "idx": env.idx,
                 "chat_completions": agent.chat_completions,
@@ -397,6 +401,7 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
                     "llm_time": llm_time,
                     # Total time spent in the trajectory
                     "total_time": total_time,
+                    "token_mismatch": 0.0 if is_valid_trajectory else 1.0,
                 },
             }
             return token_result
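Since `token_mismatch` is a per-trajectory 0/1 metric, a consumer can average it to track how often retokenization breaks the cumulative-prefix assumption. A hedged usage sketch; the `results` list is assumed to hold many `token_result` dicts:

```python
results = [token_result]  # e.g. collected across many rollouts
# Fraction of trajectories whose steps failed the cumulative-prefix check.
mismatch_rate = sum(r["metrics"]["token_mismatch"] for r in results) / len(results)
```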
@@ -410,6 +415,71 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
                 "mc_returns": [step.mc_return for step in trajectory.steps][: len(episode_steps)],
             }
             return steps_result
+        else:
+            raise ValueError(f"Mode {mode} not supported")
+
+    def assemble_steps(self, steps: list[dict]):
+        """
+        Transform step-by-step results into trajectory format for training.
+        The assembly is aggressive: if the steps are not cumulative, response_masks is set to all 0s.
+
+        Each step_result contains:
+        - steps: List of {"prompt": str, "response": str, "prompt_ids": list, "completion_ids": list}
+
+        For training, we need to assemble the full conversation sequence where:
+        - prompt_tokens: Initial prompt (first step's prompt_ids)
+        - response_tokens: All subsequent conversation (completion_ids + next step's prompt_ids)
+        - response_masks: Mask indicating which tokens contribute to loss (only completion_ids)
+        """
+
+        # Start with the initial prompt from the first step
+        initial_prompt_ids = steps[0]["prompt_ids"]
+        accumulated_sequence = initial_prompt_ids.copy()
+        response_tokens = []
+        response_masks = []
+        is_valid_trajectory = True
+
+        for i, step in enumerate(steps):
+            current_prompt_ids = step["prompt_ids"]
+            current_completion_ids = step["completion_ids"]
+
+            if i == 0:
+                # First step: just add the completion
+                response_tokens.extend(current_completion_ids)
+                response_masks.extend([1] * len(current_completion_ids))  # completion contributes to loss
+                accumulated_sequence.extend(current_completion_ids)
+            else:
+                if current_prompt_ids[: len(accumulated_sequence)] != accumulated_sequence:
+                    # Find the first differing position
+                    prefix = current_prompt_ids[: len(accumulated_sequence)]
+                    diff_pos = None
+                    for pos, (expected, actual) in enumerate(zip(accumulated_sequence, prefix, strict=False)):
+                        if expected != actual:
+                            diff_pos = pos
+                            break
+
+                    if diff_pos is not None:
+                        logger.warning(f"While assembling steps, detected a non-cumulative trajectory at position {diff_pos}. Expected: {accumulated_sequence[diff_pos : diff_pos + 5]}, Got: {prefix[diff_pos : diff_pos + 5]}. Setting response_masks to all 0s. This is likely due to retokenization.")
+                    else:
+                        logger.warning(f"While assembling steps, detected a length mismatch. Expected length: {len(accumulated_sequence)}, Got length: {len(prefix)}. Setting response_masks to all 0s.")
+
+                    is_valid_trajectory = False
+                    break
+
+                response_tokens.extend(current_prompt_ids[len(accumulated_sequence) :] + current_completion_ids)
+                response_masks.extend([0] * (len(current_prompt_ids) - len(accumulated_sequence)) + [1] * len(current_completion_ids))  # completion contributes to loss
+                accumulated_sequence = current_prompt_ids + current_completion_ids
+
+        assert len(response_masks) == len(response_tokens)
+
+        prompt_tokens = torch.tensor(initial_prompt_ids, dtype=torch.long)
+        response_tokens = torch.tensor(response_tokens, dtype=torch.long)
+        response_masks = torch.tensor(response_masks, dtype=torch.long)
+
+        if self.config.rllm.filter_token_mismatch:
+            response_masks = response_masks * int(is_valid_trajectory)
+
+        return prompt_tokens, response_tokens, response_masks, is_valid_trajectory
 
     async def run_agent_trajectory_with_retry(self, idx, application_id, seed=0, mode="Text", **kwargs):
         for _ in range(self.retry_limit):
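To make the `assemble_steps` invariant concrete, here is a standalone sketch of the same cumulative-prefix check on toy ids. It is independent of the engine, logger, and config wiring above, and only illustrates the masking rule: re-fed prompt tokens get mask 0, completion tokens get mask 1.

```python
import torch


def assemble_toy(steps: list[dict]):
    """Same prefix check as assemble_steps, on plain lists of ints."""
    accumulated = steps[0]["prompt_ids"].copy()
    response_tokens, response_masks = [], []
    for i, step in enumerate(steps):
        p, c = step["prompt_ids"], step["completion_ids"]
        if i > 0 and p[: len(accumulated)] != accumulated:
            return None  # non-cumulative: the caller zeroes out the masks
        new_prompt = [] if i == 0 else p[len(accumulated) :]
        response_tokens += new_prompt + c
        response_masks += [0] * len(new_prompt) + [1] * len(c)  # only completions train
        accumulated += new_prompt + c
    return torch.tensor(response_tokens), torch.tensor(response_masks)


# Step 2's prompt literally extends step 1's full sequence, so assembly succeeds.
steps = [
    {"prompt_ids": [1, 2], "completion_ids": [3]},
    {"prompt_ids": [1, 2, 3, 4], "completion_ids": [5]},
]
tokens, masks = assemble_toy(steps)  # tokens = [3, 4, 5], masks = [1, 0, 1]
```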