1616import torch
1717from transformers import AutoTokenizer
1818
19- from rllm .agents .agent import Episode , Trajectory
19+ from rllm .agents .agent import Episode
2020from rllm .engine .agent_workflow_engine import AgentWorkflowEngine
2121from rllm .engine .rollout .tinker_engine import TinkerEngine
2222from rllm .trainer .tinker .tinker_agent_trainer import TinkerAgentTrainer
23- from rllm .trainer .tinker .tinker_data_processor import TrajectoryGroup
2423from rllm .trainer .tinker .tinker_policy_trainer import TinkerPolicyTrainer
2524
2625if TYPE_CHECKING :
@@ -119,15 +118,15 @@ def init_envs_and_agents(self, batch_data):
119118 self .current_batch = batch_data
120119
121120 async def validate_agent (self , dataloader , sampling_client ):
122- all_trajectory_groups = []
121+ all_episodes = []
123122 all_episode_metrics = {} # episode_id -> episode.metrics dict
124123 self .agent_execution_engine .rollout_engine .set_sampling_client (sampling_client )
125124 for batch in dataloader :
126125 batch = self .build_interleave_batch (batch , 1 )
127126 self .init_envs_and_agents (batch )
128- # For validation, collect all trajectory groups from generator
129- async for trajectory_groups , episode_metrics in self .generate_agent_episodes (group_size = 1 , minibatch_size = 1 , return_metrics = True ):
130- all_trajectory_groups .extend (trajectory_groups )
127+ # For validation, collect all episodes from generator
128+ async for episodes , episode_metrics in self .generate_agent_episodes (group_size = 1 , minibatch_size = 1 , return_metrics = True ):
129+ all_episodes .extend (episodes )
131130 all_episode_metrics .update (episode_metrics )
132131
133132 # Collect workflow metrics per episode (deduplicated by episode.id)
@@ -138,10 +137,10 @@ async def validate_agent(self, dataloader, sampling_client):
138137 for key , value in episode_metric_dict .items ():
139138 workflow_metrics [key ].append (float (value ))
140139
141- # Compute trajectory-level statistics from all groups
140+ # Compute trajectory-level statistics from all episodes
142141 all_trajectories = []
143- for group in all_trajectory_groups :
144- all_trajectories .extend (group .trajectories )
142+ for episode in all_episodes :
143+ all_trajectories .extend (episode .trajectories )
145144
146145 mean_reward = sum ([traj .reward for traj in all_trajectories ]) / len (all_trajectories )
147146 std_reward = sum ([(traj .reward - mean_reward ) ** 2 for traj in all_trajectories ]) / len (all_trajectories )
@@ -165,17 +164,14 @@ async def validate_agent(self, dataloader, sampling_client):
165164
166165 async def generate_agent_episodes (self , timing_raw = None , meta_info = None , group_size = None , minibatch_size = None , return_metrics = False ):
167166 """
168- Generate trajectory groups in minibatches with overlapping generation and training.
169-
170- This uses a background producer task to continuously generate episodes (from rollout)
171- and regroups them into TrajectoryGroup objects for advantage computation.
167+ Generate episodes from workflow execution.
172168
173169 Args:
174- return_metrics: If True, yields (trajectory_groups , metrics) tuple where metrics is
175- {episode_id: {metric_name: value, ...}}. If False, yields only trajectory_groups .
170+ return_metrics: If True, yields (episodes , metrics) tuple where metrics is
171+ {episode_id: {metric_name: value, ...}}. If False, yields only episodes .
176172
177173 Yields:
178- list[TrajectoryGroup ] or tuple[list[TrajectoryGroup ], dict] depending on return_metrics
174+ list[Episode ] or tuple[list[Episode ], dict] depending on return_metrics
179175 """
180176
181177 num_minibatches = self .config .training .num_minibatches
@@ -187,116 +183,21 @@ async def generate_agent_episodes(self, timing_raw=None, meta_info=None, group_s
187183
188184 episodes = await self .agent_execution_engine .execute_tasks (current_batch , task_ids )
189185 episodes = self .make_sure_contain_token_and_logprob (episodes )
190- trajectory_groups , episode_metrics = self .regroup (episodes )
191-
192- if return_metrics :
193- yield trajectory_groups , episode_metrics
194- else :
195- yield trajectory_groups
196-
197- def regroup (self , episodes : list [Episode ]) -> tuple [list [TrajectoryGroup ], dict ]:
198- """
199- Regroup episodes into TrajectoryGroup objects based on grouping_level configuration.
200-
201- The grouping level determines how advantages are computed:
202-
203- - trajectory: Group trajectories by (task_id, trajectory_name). Each trajectory
204- keeps all its steps. Trajectories with different names are grouped
205- separately (important for multi-agent scenarios). Advantage is
206- computed per trajectory (e.g., via GRPO across trajectory rewards),
207- then broadcast to all steps in that trajectory during datum creation.
208186
209- - step: Group individual steps at the same position (task_id + trajectory_name
210- + step_idx) across different rollouts. Each step becomes a single-step
211- trajectory in a group. Advantage is computed per step (e.g., via GRPO
212- across step rewards).
213-
214- The resulting TrajectoryGroup objects are consumed by process_trajectory_groups() which:
215- 1. Extracts rewards from trajectories in each group
216- 2. Computes advantages across those rewards
217- 3. Assigns each trajectory its computed advantage
218- 4. Broadcasts the advantage to all steps in the trajectory
187+ # Update trajectory-level rewards from step-level rewards
188+ for episode in episodes :
189+ for trajectory in episode .trajectories :
190+ if trajectory .reward == 0.0 and trajectory .steps :
191+ # Compute trajectory reward from step rewards
192+ trajectory .reward = sum (step .reward if step .reward is not None else 0.0 for step in trajectory .steps )
219193
220- Args:
221- episodes: List of episodes to regroup
194+ # Extract episode metrics if available
195+ episode_metrics = { ep . id : ep . metrics for ep in episodes if hasattr ( ep , "metrics" ) and ep . metrics }
222196
223- Returns:
224- Tuple of (trajectory_groups, metrics_dict)
225- """
226- grouping_level = self .config .algorithm .grouping_level
227- trajectory_groups = []
228- metrics = {}
229-
230- def get_task_id (episode : Episode ):
231- return ":" .join (episode .id .split (":" )[:- 1 ])
232-
233- if grouping_level == "trajectory" :
234- # Group trajectories by (task_id, trajectory_name)
235- # This ensures trajectories with different names are grouped separately
236- temp_groups = defaultdict (list )
237-
238- for episode in episodes :
239- if episode .id not in metrics and episode .metrics :
240- metrics [episode .id ] = episode .metrics
241- task_id = get_task_id (episode )
242-
243- # Add all trajectories to the group for this (task_id, trajectory_name)
244- for trajectory in episode .trajectories :
245- # Each trajectory keeps all its steps
246- # Compute trajectory-level reward as the sum/mean of step rewards
247- traj_reward = trajectory .reward if trajectory .reward is not None else sum (step .reward for step in trajectory .steps )
248- # Update trajectory with proper reward
249- updated_trajectory = Trajectory (steps = trajectory .steps , reward = traj_reward , name = trajectory .name )
250- # Group by both task_id and trajectory name
251- group_key = (task_id , trajectory .name )
252- temp_groups [group_key ].append (updated_trajectory )
253-
254- # Create TrajectoryGroup objects from grouped trajectories
255- for group_key , trajectories in temp_groups .items ():
256- group_id = f"{ group_key [0 ]} :{ group_key [1 ]} " # "task_id:trajectory_name"
257- trajectory_group = TrajectoryGroup (trajectories = trajectories , group_id = group_id )
258- trajectory_groups .append (trajectory_group )
259-
260- print ("Trajectory-level grouping:" )
261- print (f" len episodes: { len (episodes )} " )
262- print (f" len unique (task_id, traj_name) groups: { len (temp_groups )} " )
263- print (f" len trajectory_groups: { len (trajectory_groups )} " )
264-
265- elif grouping_level == "step" :
266- # Group individual steps by step position
267- unique_step_uids = set ()
268- unique_task_ids = set ()
269- step_groupby_step_uid = defaultdict (list )
270-
271- for episode in episodes :
272- if episode .id not in metrics and episode .metrics :
273- metrics [episode .id ] = episode .metrics
274- task_id = get_task_id (episode )
275- unique_task_ids .add (task_id )
276-
277- for trajectory in episode .trajectories :
278- for step_idx , step in enumerate (trajectory .steps ):
279- step_uid = f"{ task_id } :{ trajectory .name } :{ step_idx } "
280- if step_uid not in unique_step_uids :
281- unique_step_uids .add (step_uid )
282-
283- step_groupby_step_uid [step_uid ].append (step )
284-
285- # Create TrajectoryGroup objects where each trajectory contains a single step
286- for step_uid , steps in step_groupby_step_uid .items ():
287- trajectories = [Trajectory (steps = [step ], reward = step .reward ) for step in steps ]
288- trajectory_group = TrajectoryGroup (trajectories = trajectories , group_id = step_uid )
289- trajectory_groups .append (trajectory_group )
290-
291- print ("Step-level grouping:" )
292- print (f" len episodes: { len (episodes )} " )
293- print (f" len unique_task_ids: { len (unique_task_ids )} " )
294- print (f" len unique_step_uids: { len (unique_step_uids )} " )
295- print (f" len trajectory_groups: { len (trajectory_groups )} " )
197+ if return_metrics :
198+ yield episodes , episode_metrics
296199 else :
297- raise ValueError (f"Invalid grouping_level: { grouping_level } . Must be 'trajectory' or 'step'" )
298-
299- return trajectory_groups , metrics
200+ yield episodes
300201
301202 def make_sure_contain_token_and_logprob (self , episodes : list [Episode ]) -> list [Episode ]:
302203 for episode in episodes :
0 commit comments