Skip to content

Commit 808cc08

Browse files
add metrics for grouping and advantage stats
1 parent 1b157fa commit 808cc08

File tree

3 files changed

+55
-11
lines changed

3 files changed

+55
-11
lines changed

rllm/trainer/tinker/tinker_agent_trainer.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ async def _fit_agent_async(self):
201201

202202
# Stream: train on each minibatch as it arrives
203203
train_step_start = time.time()
204+
all_grouping_metrics = []
204205
async for minibatch_episodes in self.generate_agent_episodes(group_size=self.config.training.group_size, minibatch_size=minibatch_size):
205206
episodes.extend(minibatch_episodes)
206207
minibatch_count += 1
@@ -216,10 +217,11 @@ async def _fit_agent_async(self):
216217

217218
# Train immediately (streaming), only optimize on last minibatch
218219
t_train_start = time.time()
219-
logprobs, datums = await self.trainer.step(minibatch_episodes, learning_rate=learning_rate, beta1=beta1, beta2=beta2, eps=eps, optimizer_step=False)
220+
logprobs, datums, grouping_metrics = await self.trainer.step(minibatch_episodes, learning_rate=learning_rate, beta1=beta1, beta2=beta2, eps=eps, optimizer_step=False)
220221
forward_backward_times.append(time.time() - t_train_start)
221222
training_logprobs.extend(logprobs)
222223
training_datums.extend(datums)
224+
all_grouping_metrics.append(grouping_metrics)
223225
logger.info(f"Processed minibatch {minibatch_count}/{num_minibatches} with {len(minibatch_episodes)} episodes")
224226

225227
optim_step_time = time.time()
@@ -246,6 +248,19 @@ async def _fit_agent_async(self):
246248
training_datums=training_datums, # Pass datums for KL/perplexity metrics
247249
training_logprobs=training_logprobs,
248250
)
251+
252+
# Aggregate grouping metrics from all minibatches
253+
if all_grouping_metrics:
254+
import numpy as np
255+
256+
# Average numeric metrics across minibatches
257+
aggregated_grouping_metrics = {}
258+
for key in all_grouping_metrics[0].keys():
259+
values = [m[key] for m in all_grouping_metrics if key in m]
260+
if values:
261+
aggregated_grouping_metrics[key] = np.mean(values)
262+
metrics.update(aggregated_grouping_metrics)
263+
249264
tracking_logger.log(data=metrics, step=batch_idx)
250265
print_metrics_table(metrics, batch_idx)
251266

rllm/trainer/tinker/tinker_data_processor.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def process_episodes(
401401
advantage_computer: TinkerAdvantageComputer,
402402
trajectory_filter: TinkerTrajectoryFilter,
403403
algorithm_config,
404-
) -> list[tinker.Datum]:
404+
) -> tuple[list[tinker.Datum], dict]:
405405
"""
406406
Main pipeline to convert Episode objects to training datums.
407407
@@ -423,10 +423,14 @@ def process_episodes(
423423
algorithm_config: Configuration with grouping_level setting
424424
425425
Returns:
426-
List of Tinker Datum objects ready for training
426+
Tuple of (datums, metrics_dict):
427+
- datums: List of Tinker Datum objects ready for training
428+
- metrics_dict: Dictionary with grouping and advantage statistics
427429
"""
428430
from collections import defaultdict
429431

432+
import numpy as np
433+
430434
grouping_level = algorithm_config.get("grouping_level", "episode")
431435

432436
# Group trajectories based on grouping_level
@@ -469,6 +473,10 @@ def get_task_id(episode):
469473
# Apply filtering based on configuration
470474
filtered_groups = trajectory_filter.filter_groups(trajectory_groups)
471475

476+
# Track metrics
477+
all_advantages = []
478+
group_sizes = []
479+
472480
training_datums = []
473481
for group in filtered_groups:
474482
# Extract rewards for the group (from all trajectories)
@@ -477,13 +485,33 @@ def get_task_id(episode):
477485
# Compute advantages
478486
advantages = advantage_computer.compute(group_rewards)
479487

488+
# Track for metrics
489+
all_advantages.extend(advantages)
490+
group_sizes.append(len(group.trajectories))
491+
480492
# Create datums for all trajectories in the group
481493
for trajectory, advantage in zip(group.trajectories, advantages, strict=False):
482494
# Use trajectory-level building (merges steps when possible)
483495
new_datums = TinkerDatumBuilder.build_datum_from_trajectory(trajectory, advantage)
484496
training_datums.extend(new_datums)
485497

486-
return training_datums
498+
# Compute grouping and advantage metrics
499+
metrics = {}
500+
if filtered_groups:
501+
metrics["grouping/num_groups"] = len(filtered_groups)
502+
metrics["grouping/num_groups_before_filter"] = len(trajectory_groups)
503+
metrics["grouping/avg_group_size"] = np.mean(group_sizes)
504+
metrics["grouping/max_group_size"] = np.max(group_sizes)
505+
metrics["grouping/min_group_size"] = np.min(group_sizes)
506+
507+
if all_advantages:
508+
metrics["advantage/mean"] = np.mean(all_advantages)
509+
metrics["advantage/std"] = np.std(all_advantages)
510+
metrics["advantage/max"] = np.max(all_advantages)
511+
metrics["advantage/min"] = np.min(all_advantages)
512+
metrics["advantage/fraction_zero"] = np.sum(np.abs(all_advantages) < 1e-8) / len(all_advantages)
513+
514+
return training_datums, metrics
487515

488516

489517
def process_trajectory_groups(

rllm/trainer/tinker/tinker_policy_trainer.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ async def step(
132132
beta2: float = 0.95,
133133
eps: float = 1e-8,
134134
optimizer_step: bool = True,
135-
) -> tuple[list[torch.Tensor], list[tinker.Datum]]:
135+
) -> tuple[list[torch.Tensor], list[tinker.Datum], dict]:
136136
"""
137137
Complete training step: process episodes and update policy.
138138
@@ -147,15 +147,16 @@ async def step(
147147
optimizer_step: Whether to apply optimizer step after forward-backward
148148
149149
Returns:
150-
Tuple of (training_logprobs, training_datums)
150+
Tuple of (training_logprobs, training_datums, grouping_metrics)
151151
- training_logprobs: List of training logprobs for KL computation
152152
- training_datums: List of datums WITH masks for metrics
153+
- grouping_metrics: Dict with grouping and advantage statistics
153154
"""
154155
if learning_rate is None:
155156
learning_rate = self.config.training.learning_rate
156157

157158
# Step 1: Process to datums (includes filtering and advantage computation)
158-
training_datums = process_episodes(
159+
training_datums, grouping_metrics = process_episodes(
159160
episodes,
160161
self.advantage_computer,
161162
self.trajectory_filter,
@@ -193,11 +194,11 @@ async def step(
193194
training_logprobs = output["logprobs"].to_torch()
194195
training_logprobs_D.append(training_logprobs)
195196

196-
# Return both logprobs and datums (with masks for metrics)
197-
return training_logprobs_D, training_datums
197+
# Return logprobs, datums (with masks for metrics), and grouping metrics
198+
return training_logprobs_D, training_datums, grouping_metrics
198199

199200
async def forward_backward_future(self, episodes: list):
200-
training_datums = process_episodes(
201+
training_datums, grouping_metrics = process_episodes(
201202
episodes,
202203
self.advantage_computer,
203204
self.trajectory_filter,
@@ -211,7 +212,7 @@ async def forward_backward_future(self, episodes: list):
211212
loss_fn="importance_sampling",
212213
)
213214

214-
return fwd_bwd_future
215+
return fwd_bwd_future, grouping_metrics
215216

216217
async def optim_step_future(self, learning_rate: float = None, beta1: float = 0.9, beta2: float = 0.95, eps: float = 1e-8):
217218
if learning_rate is None:

0 commit comments

Comments (0)