
Commit 1b157fa

pipeline to process Episodes into TrajectoryGroup instead of regrouping
1 parent 8f82ca5 commit 1b157fa

File tree

5 files changed (+129 additions, -145 deletions)


examples/solver_judge_tinker/train_solver_judge_flow_tinker.sh

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ python -m examples.solver_judge_tinker.train_solver_judge_flow_tinker \
     sampling.top_p=1.0 \
     algorithm.adv_estimator=grpo \
     algorithm.norm_adv_by_std_in_grpo=true \
-    algorithm.grouping_level=step \
+    algorithm.grouping_level=trajectory \
     data.max_prompt_length=2048 \
     data.max_response_length=1024 \
     data.train_batch_size=64 \
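
This override flips the solver-judge example from step-level to trajectory-level grouping, so GRPO advantages are computed across whole solver (or judge) trajectories of the same task rather than across aligned steps. A minimal sketch of how such an override reaches the new pipeline, assuming an OmegaConf/Hydra-style config (the exact config plumbing in rllm may differ):

from omegaconf import OmegaConf

# Illustrative config only; in a real run it comes from the command-line overrides above.
config = OmegaConf.create({"algorithm": {"adv_estimator": "grpo", "grouping_level": "trajectory"}})

# process_episodes() reads the key the same way, falling back to "episode" when it is absent.
grouping_level = config.algorithm.get("grouping_level", "episode")
print(grouping_level)  # -> trajectory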

rllm/trainer/tinker/tinker_agent_trainer.py

Lines changed: 1 addition & 6 deletions
@@ -23,7 +23,6 @@
 
 from rllm.agents.agent import Episode, Step, Trajectory
 from rllm.engine.agent_execution_engine import AsyncAgentExecutionEngine
-from rllm.trainer.tinker.tinker_data_processor import TrajectoryGroup
 from rllm.trainer.tinker.tinker_metrics_utils import (
     compute_training_metrics,
     print_episodes,
@@ -215,13 +214,9 @@ async def _fit_agent_async(self):
 
                 logger.info(f"Training for batch {batch_idx}, minibatch {minibatch_count}/{num_minibatches}")
 
-                # Convert episodes to trajectory groups
-                # For agent trainer, each episode becomes one group (simple conversion)
-                trajectory_groups = [TrajectoryGroup(trajectories=episode.trajectories, group_id=episode.id if hasattr(episode, "id") else f"group_{i}") for i, episode in enumerate(minibatch_episodes)]
-
                 # Train immediately (streaming), only optimize on last minibatch
                 t_train_start = time.time()
-                logprobs, datums = await self.trainer.step(trajectory_groups, learning_rate=learning_rate, beta1=beta1, beta2=beta2, eps=eps, optimizer_step=False)
+                logprobs, datums = await self.trainer.step(minibatch_episodes, learning_rate=learning_rate, beta1=beta1, beta2=beta2, eps=eps, optimizer_step=False)
                 forward_backward_times.append(time.time() - t_train_start)
                 training_logprobs.extend(logprobs)
                 training_datums.extend(datums)
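
After this change the agent trainer no longer wraps episodes in TrajectoryGroup objects; the minibatch of Episode objects goes straight into TinkerPolicyTrainer.step, which now owns grouping and advantage computation. A minimal sketch of the resulting call pattern, with trainer, minibatches, and learning_rate as assumed placeholders for the attributes used in _fit_agent_async:

async def train_minibatches(trainer, minibatches, learning_rate):
    """Sketch only: stream minibatches of Episode objects into the policy trainer."""
    all_logprobs, all_datums = [], []
    for minibatch_episodes in minibatches:
        # Episodes are passed as-is; grouping happens inside process_episodes().
        logprobs, datums = await trainer.step(
            minibatch_episodes,
            learning_rate=learning_rate,
            optimizer_step=False,  # as in the diff, the optimizer step is applied separately
        )
        all_logprobs.extend(logprobs)
        all_datums.extend(datums)
    return all_logprobs, all_datums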

rllm/trainer/tinker/tinker_data_processor.py

Lines changed: 90 additions & 0 deletions
@@ -396,6 +396,96 @@ def to_datum(self) -> tinker.Datum:
         return datums
 
 
+def process_episodes(
+    episodes: list,
+    advantage_computer: TinkerAdvantageComputer,
+    trajectory_filter: TinkerTrajectoryFilter,
+    algorithm_config,
+) -> list[tinker.Datum]:
+    """
+    Main pipeline to convert Episode objects to training datums.
+
+    This function:
+    1. Groups trajectories based on grouping_level configuration
+    2. Computes advantages for each group
+    3. Builds Tinker Datums for training
+
+    Grouping levels:
+    - trajectory: Group trajectories by (task_id, trajectory_name) for multi-agent workflows.
+                  Advantage computed across trajectory rewards.
+    - step: Group individual steps at same position for step-level advantage computation.
+    - episode: Each episode's trajectories form one group (simple single-agent case).
+
+    Args:
+        episodes: List of Episode objects
+        advantage_computer: Computer for calculating advantages
+        trajectory_filter: Filter for removing constant-reward groups
+        algorithm_config: Configuration with grouping_level setting
+
+    Returns:
+        List of Tinker Datum objects ready for training
+    """
+    from collections import defaultdict
+
+    grouping_level = algorithm_config.get("grouping_level", "episode")
+
+    # Group trajectories based on grouping_level
+    trajectory_groups_dict = defaultdict(list)
+
+    def get_task_id(episode):
+        """Extract task_id from episode.id (format: task_id:rollout_idx)"""
+        return ":".join(episode.id.split(":")[:-1]) if ":" in episode.id else episode.id
+
+    if grouping_level == "trajectory":
+        # Group by (task_id, trajectory_name) - for multi-agent workflows like solver-judge
+        for episode in episodes:
+            task_id = get_task_id(episode)
+            for trajectory in episode.trajectories:
+                group_key = (task_id, trajectory.name)
+                trajectory_groups_dict[group_key].append(trajectory)
+
+    elif grouping_level == "step":
+        # Group by (task_id, trajectory_name, step_idx) - for step-level advantages
+        for episode in episodes:
+            task_id = get_task_id(episode)
+            for trajectory in episode.trajectories:
+                for step_idx, step in enumerate(trajectory.steps):
+                    group_key = (task_id, trajectory.name, step_idx)
+                    # Create single-step trajectory
+                    from rllm.agents.agent import Trajectory
+
+                    single_step_traj = Trajectory(steps=[step], reward=step.reward, name=trajectory.name)
+                    trajectory_groups_dict[group_key].append(single_step_traj)
+
+    else:  # "episode" or default
+        # Simple grouping: all trajectories in an episode form one group
+        for episode in episodes:
+            group_key = episode.id
+            trajectory_groups_dict[group_key].extend(episode.trajectories)
+
+    # Convert dict to list of TrajectoryGroup objects for filtering
+    trajectory_groups = [TrajectoryGroup(trajectories=trajs, group_id=str(key)) for key, trajs in trajectory_groups_dict.items()]
+
+    # Apply filtering based on configuration
+    filtered_groups = trajectory_filter.filter_groups(trajectory_groups)
+
+    training_datums = []
+    for group in filtered_groups:
+        # Extract rewards for the group (from all trajectories)
+        group_rewards = [traj.reward for traj in group.trajectories]
+
+        # Compute advantages
+        advantages = advantage_computer.compute(group_rewards)
+
+        # Create datums for all trajectories in the group
+        for trajectory, advantage in zip(group.trajectories, advantages, strict=False):
+            # Use trajectory-level building (merges steps when possible)
+            new_datums = TinkerDatumBuilder.build_datum_from_trajectory(trajectory, advantage)
+            training_datums.extend(new_datums)
+
+    return training_datums
+
+
 def process_trajectory_groups(
     groups: list[TrajectoryGroup],
     advantage_computer: TinkerAdvantageComputer,
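
To make the grouping concrete, here is a standalone toy illustration of the trajectory-level branch together with a group-relative (mean-baseline) advantage. ToyTrajectory, ToyEpisode, and mean_baseline_advantages are simplified stand-ins for the rllm classes and TinkerAdvantageComputer, not the actual implementation:

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class ToyTrajectory:
    name: str
    reward: float


@dataclass
class ToyEpisode:
    id: str                                   # "task_id:rollout_idx", as assumed by get_task_id()
    trajectories: list = field(default_factory=list)


def group_by_trajectory(episodes):
    """Mirrors the grouping_level == "trajectory" branch: key = (task_id, trajectory_name)."""
    groups = defaultdict(list)
    for ep in episodes:
        task_id = ep.id.rsplit(":", 1)[0]
        for traj in ep.trajectories:
            groups[(task_id, traj.name)].append(traj)
    return groups


def mean_baseline_advantages(rewards):
    """Simplified group-relative advantage: reward minus the group mean (no std normalization)."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


episodes = [
    ToyEpisode("task0:0", [ToyTrajectory("solver", 1.0), ToyTrajectory("judge", 0.0)]),
    ToyEpisode("task0:1", [ToyTrajectory("solver", 0.0), ToyTrajectory("judge", 1.0)]),
]

for key, trajs in group_by_trajectory(episodes).items():
    rewards = [t.reward for t in trajs]
    print(key, rewards, mean_baseline_advantages(rewards))
# ('task0', 'solver') [1.0, 0.0] [0.5, -0.5]
# ('task0', 'judge') [0.0, 1.0] [-0.5, 0.5]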

rllm/trainer/tinker/tinker_policy_trainer.py

Lines changed: 14 additions & 16 deletions
@@ -17,8 +17,7 @@
 from rllm.trainer.tinker.tinker_data_processor import (
     TinkerAdvantageComputer,
     TinkerTrajectoryFilter,
-    TrajectoryGroup,
-    process_trajectory_groups,
+    process_episodes,
 )
 
 if TYPE_CHECKING:
@@ -127,26 +126,23 @@ def _remove_mask(self, datum: tinker.Datum) -> tinker.Datum:
 
     async def step(
         self,
-        groups: list[TrajectoryGroup],
+        episodes: list,
         learning_rate: float = None,
         beta1: float = 0.9,
         beta2: float = 0.95,
         eps: float = 1e-8,
         optimizer_step: bool = True,
     ) -> tuple[list[torch.Tensor], list[tinker.Datum]]:
         """
-        Complete training step: process trajectory groups and update policy.
+        Complete training step: process episodes and update policy.
 
         This method:
-        1. Converts episodes to trajectory groups if needed
-        2. Filters groups (if configured)
-        3. Computes advantages
-        4. Converts to datums
-        5. Performs forward-backward pass
-        6. Applies optimizer step
+        1. Processes episodes to compute advantages and create datums
+        2. Performs forward-backward pass
+        3. Applies optimizer step
 
         Args:
-            data: List of Episode or TrajectoryGroup objects
+            episodes: List of Episode objects
             learning_rate: Learning rate (uses config value if None)
             optimizer_step: Whether to apply optimizer step after forward-backward
 
@@ -159,10 +155,11 @@ async def step(
             learning_rate = self.config.training.learning_rate
 
         # Step 1: Process to datums (includes filtering and advantage computation)
-        training_datums = process_trajectory_groups(
-            groups,
+        training_datums = process_episodes(
+            episodes,
             self.advantage_computer,
             self.trajectory_filter,
+            self.config.algorithm,
         )
 
         # Step 3: Remove mask from datums (not needed by forward_backward)
@@ -199,11 +196,12 @@ async def step(
         # Return both logprobs and datums (with masks for metrics)
         return training_logprobs_D, training_datums
 
-    async def forward_backward_future(self, groups: list[TrajectoryGroup]):
-        training_datums = process_trajectory_groups(
-            groups,
+    async def forward_backward_future(self, episodes: list):
+        training_datums = process_episodes(
+            episodes,
             self.advantage_computer,
             self.trajectory_filter,
+            self.config.algorithm,
         )
 
         datums_no_mask = [self._remove_mask(datum) for datum in training_datums]
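
The trainer's public surface thus changes from list[TrajectoryGroup] to a plain list of Episode objects, with self.config.algorithm forwarded so process_episodes can select the grouping level. A hedged sketch of a caller, assuming trainer is an already constructed TinkerPolicyTrainer:

async def train_on_episodes(trainer, episodes, learning_rate):
    """Sketch only: one complete training step over a batch of Episode objects."""
    logprobs, datums = await trainer.step(
        episodes,              # list[Episode]; grouping and advantages are handled internally
        learning_rate=learning_rate,
        optimizer_step=True,   # forward-backward followed by an optimizer update
    )
    return logprobs, datums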

rllm/trainer/tinker/tinker_workflow_trainer.py

Lines changed: 23 additions & 122 deletions
@@ -16,11 +16,10 @@
 import torch
 from transformers import AutoTokenizer
 
-from rllm.agents.agent import Episode, Trajectory
+from rllm.agents.agent import Episode
 from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
 from rllm.engine.rollout.tinker_engine import TinkerEngine
 from rllm.trainer.tinker.tinker_agent_trainer import TinkerAgentTrainer
-from rllm.trainer.tinker.tinker_data_processor import TrajectoryGroup
 from rllm.trainer.tinker.tinker_policy_trainer import TinkerPolicyTrainer
 
 if TYPE_CHECKING:
@@ -119,15 +118,15 @@ def init_envs_and_agents(self, batch_data):
         self.current_batch = batch_data
 
     async def validate_agent(self, dataloader, sampling_client):
-        all_trajectory_groups = []
+        all_episodes = []
         all_episode_metrics = {}  # episode_id -> episode.metrics dict
         self.agent_execution_engine.rollout_engine.set_sampling_client(sampling_client)
         for batch in dataloader:
             batch = self.build_interleave_batch(batch, 1)
             self.init_envs_and_agents(batch)
-            # For validation, collect all trajectory groups from generator
-            async for trajectory_groups, episode_metrics in self.generate_agent_episodes(group_size=1, minibatch_size=1, return_metrics=True):
-                all_trajectory_groups.extend(trajectory_groups)
+            # For validation, collect all episodes from generator
+            async for episodes, episode_metrics in self.generate_agent_episodes(group_size=1, minibatch_size=1, return_metrics=True):
+                all_episodes.extend(episodes)
                 all_episode_metrics.update(episode_metrics)
 
         # Collect workflow metrics per episode (deduplicated by episode.id)
@@ -138,10 +137,10 @@ async def validate_agent(self, dataloader, sampling_client):
             for key, value in episode_metric_dict.items():
                 workflow_metrics[key].append(float(value))
 
-        # Compute trajectory-level statistics from all groups
+        # Compute trajectory-level statistics from all episodes
         all_trajectories = []
-        for group in all_trajectory_groups:
-            all_trajectories.extend(group.trajectories)
+        for episode in all_episodes:
+            all_trajectories.extend(episode.trajectories)
 
         mean_reward = sum([traj.reward for traj in all_trajectories]) / len(all_trajectories)
         std_reward = sum([(traj.reward - mean_reward) ** 2 for traj in all_trajectories]) / len(all_trajectories)
@@ -165,17 +164,14 @@ async def validate_agent(self, dataloader, sampling_client):
 
     async def generate_agent_episodes(self, timing_raw=None, meta_info=None, group_size=None, minibatch_size=None, return_metrics=False):
         """
-        Generate trajectory groups in minibatches with overlapping generation and training.
-
-        This uses a background producer task to continuously generate episodes (from rollout)
-        and regroups them into TrajectoryGroup objects for advantage computation.
+        Generate episodes from workflow execution.
 
         Args:
-            return_metrics: If True, yields (trajectory_groups, metrics) tuple where metrics is
-                {episode_id: {metric_name: value, ...}}. If False, yields only trajectory_groups.
+            return_metrics: If True, yields (episodes, metrics) tuple where metrics is
+                {episode_id: {metric_name: value, ...}}. If False, yields only episodes.
 
         Yields:
-            list[TrajectoryGroup] or tuple[list[TrajectoryGroup], dict] depending on return_metrics
+            list[Episode] or tuple[list[Episode], dict] depending on return_metrics
         """
 
         num_minibatches = self.config.training.num_minibatches
@@ -187,116 +183,21 @@ async def generate_agent_episodes(self, timing_raw=None, meta_info=None, group_s
 
         episodes = await self.agent_execution_engine.execute_tasks(current_batch, task_ids)
         episodes = self.make_sure_contain_token_and_logprob(episodes)
-        trajectory_groups, episode_metrics = self.regroup(episodes)
-
-        if return_metrics:
-            yield trajectory_groups, episode_metrics
-        else:
-            yield trajectory_groups
-
-    def regroup(self, episodes: list[Episode]) -> tuple[list[TrajectoryGroup], dict]:
-        """
-        Regroup episodes into TrajectoryGroup objects based on grouping_level configuration.
-
-        The grouping level determines how advantages are computed:
-
-        - trajectory: Group trajectories by (task_id, trajectory_name). Each trajectory
-                      keeps all its steps. Trajectories with different names are grouped
-                      separately (important for multi-agent scenarios). Advantage is
-                      computed per trajectory (e.g., via GRPO across trajectory rewards),
-                      then broadcast to all steps in that trajectory during datum creation.
 
-        - step: Group individual steps at the same position (task_id + trajectory_name
-                + step_idx) across different rollouts. Each step becomes a single-step
-                trajectory in a group. Advantage is computed per step (e.g., via GRPO
-                across step rewards).
-
-        The resulting TrajectoryGroup objects are consumed by process_trajectory_groups() which:
-        1. Extracts rewards from trajectories in each group
-        2. Computes advantages across those rewards
-        3. Assigns each trajectory its computed advantage
-        4. Broadcasts the advantage to all steps in the trajectory
+        # Update trajectory-level rewards from step-level rewards
+        for episode in episodes:
+            for trajectory in episode.trajectories:
+                if trajectory.reward == 0.0 and trajectory.steps:
+                    # Compute trajectory reward from step rewards
+                    trajectory.reward = sum(step.reward if step.reward is not None else 0.0 for step in trajectory.steps)
 
-        Args:
-            episodes: List of episodes to regroup
+        # Extract episode metrics if available
+        episode_metrics = {ep.id: ep.metrics for ep in episodes if hasattr(ep, "metrics") and ep.metrics}
 
-        Returns:
-            Tuple of (trajectory_groups, metrics_dict)
-        """
-        grouping_level = self.config.algorithm.grouping_level
-        trajectory_groups = []
-        metrics = {}
-
-        def get_task_id(episode: Episode):
-            return ":".join(episode.id.split(":")[:-1])
-
-        if grouping_level == "trajectory":
-            # Group trajectories by (task_id, trajectory_name)
-            # This ensures trajectories with different names are grouped separately
-            temp_groups = defaultdict(list)
-
-            for episode in episodes:
-                if episode.id not in metrics and episode.metrics:
-                    metrics[episode.id] = episode.metrics
-                task_id = get_task_id(episode)
-
-                # Add all trajectories to the group for this (task_id, trajectory_name)
-                for trajectory in episode.trajectories:
-                    # Each trajectory keeps all its steps
-                    # Compute trajectory-level reward as the sum/mean of step rewards
-                    traj_reward = trajectory.reward if trajectory.reward is not None else sum(step.reward for step in trajectory.steps)
-                    # Update trajectory with proper reward
-                    updated_trajectory = Trajectory(steps=trajectory.steps, reward=traj_reward, name=trajectory.name)
-                    # Group by both task_id and trajectory name
-                    group_key = (task_id, trajectory.name)
-                    temp_groups[group_key].append(updated_trajectory)
-
-            # Create TrajectoryGroup objects from grouped trajectories
-            for group_key, trajectories in temp_groups.items():
-                group_id = f"{group_key[0]}:{group_key[1]}"  # "task_id:trajectory_name"
-                trajectory_group = TrajectoryGroup(trajectories=trajectories, group_id=group_id)
-                trajectory_groups.append(trajectory_group)
-
-            print("Trajectory-level grouping:")
-            print(f"  len episodes: {len(episodes)}")
-            print(f"  len unique (task_id, traj_name) groups: {len(temp_groups)}")
-            print(f"  len trajectory_groups: {len(trajectory_groups)}")
-
-        elif grouping_level == "step":
-            # Group individual steps by step position
-            unique_step_uids = set()
-            unique_task_ids = set()
-            step_groupby_step_uid = defaultdict(list)
-
-            for episode in episodes:
-                if episode.id not in metrics and episode.metrics:
-                    metrics[episode.id] = episode.metrics
-                task_id = get_task_id(episode)
-                unique_task_ids.add(task_id)
-
-                for trajectory in episode.trajectories:
-                    for step_idx, step in enumerate(trajectory.steps):
-                        step_uid = f"{task_id}:{trajectory.name}:{step_idx}"
-                        if step_uid not in unique_step_uids:
-                            unique_step_uids.add(step_uid)
-
-                        step_groupby_step_uid[step_uid].append(step)
-
-            # Create TrajectoryGroup objects where each trajectory contains a single step
-            for step_uid, steps in step_groupby_step_uid.items():
-                trajectories = [Trajectory(steps=[step], reward=step.reward) for step in steps]
-                trajectory_group = TrajectoryGroup(trajectories=trajectories, group_id=step_uid)
-                trajectory_groups.append(trajectory_group)
-
-            print("Step-level grouping:")
-            print(f"  len episodes: {len(episodes)}")
-            print(f"  len unique_task_ids: {len(unique_task_ids)}")
-            print(f"  len unique_step_uids: {len(unique_step_uids)}")
-            print(f"  len trajectory_groups: {len(trajectory_groups)}")
+        if return_metrics:
+            yield episodes, episode_metrics
         else:
-            raise ValueError(f"Invalid grouping_level: {grouping_level}. Must be 'trajectory' or 'step'")
-
-        return trajectory_groups, metrics
+            yield episodes
 
     def make_sure_contain_token_and_logprob(self, episodes: list[Episode]) -> list[Episode]:
         for episode in episodes:
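
With regroup() removed, generate_agent_episodes yields Episode lists directly and the grouping decision lives entirely in process_episodes. A minimal sketch of a consumer, where workflow_trainer and policy_trainer are assumed placeholders for the objects this module wires together:

async def consume_episodes(workflow_trainer, policy_trainer, learning_rate):
    """Sketch only: episodes flow from the workflow generator straight into the policy trainer."""
    async for episodes, episode_metrics in workflow_trainer.generate_agent_episodes(return_metrics=True):
        # episode_metrics: {episode_id: {metric_name: value, ...}}, as described in the docstring above
        await policy_trainer.step(episodes, learning_rate=learning_rate)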
