46 | 46 | logger = logging.getLogger(__name__)
47 | 47 |
48 | 48 |
| 49 | +def _get_evaluator_name(evaluator: SamplingClientEvaluator) -> str: |
| 50 | +    return (
| 51 | +        evaluator.name
| 52 | +        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
| 53 | +        else ""
| 54 | +    )
| 55 | + |
| 56 | + |
49 | 57 | @contextmanager |
50 | 58 | def _get_logtree_scope( |
51 | 59 |     log_path: str | None, num_groups_to_log: int, f_name: str, scope_name: str
@@ -254,6 +262,47 @@ class Config: |
254 | 262 |     num_groups_to_log: int = 4  # Number of groups to log per iteration (0 = disable logging)
255 | 263 |
256 | 264 |
| 265 | +@scope |
| 266 | +async def run_evaluations_parallel( |
| 267 | +    evaluators: list[SamplingClientEvaluator],
| 268 | +    sampling_client: tinker.SamplingClient,
| 269 | +    cfg: Config,
| 270 | +    i_batch: int,
| 271 | +) -> dict[str, Any]:
| 272 | +    """Run all evaluators in parallel and return aggregated metrics."""
| 273 | +
| 274 | +    async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client):
| 275 | +        ev_name = _get_evaluator_name(evaluator)
| 276 | +        with _get_logtree_scope(
| 277 | +            log_path=cfg.log_path,
| 278 | +            num_groups_to_log=cfg.num_groups_to_log,
| 279 | +            f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
| 280 | +            scope_name=f"Running evaluation {ev_name} {i_batch}",
| 281 | +        ):
| 282 | +            eval_metrics = await evaluator(sampling_client)
| 283 | +            return {f"test/{k}": v for k, v in eval_metrics.items()}
| 284 | +
| 285 | +    # Create tasks for all evaluators with names for better traceability
| 286 | +    tasks = []
| 287 | +    for i, evaluator in enumerate(evaluators):
| 288 | +        ev_name = _get_evaluator_name(evaluator)
| 289 | +        task = asyncio.create_task(
| 290 | +            run_single_evaluation(evaluator, cfg, i_batch, sampling_client),
| 291 | +            name=f"eval_{ev_name or i}_iteration_{i_batch:06d}",
| 292 | +        )
| 293 | +        tasks.append(task)
| 294 | +
| 295 | +    # Wait for all to complete
| 296 | +    results = await asyncio.gather(*tasks)
| 297 | +
| 298 | +    # Merge all metrics
| 299 | +    metrics = {}
| 300 | +    for result in results:
| 301 | +        metrics.update(result)
| 302 | +
| 303 | +    return metrics
| 304 | + |
| 305 | + |
257 | 306 | @scope |
258 | 307 | async def do_sync_training_with_stream_minibatch( |
259 | 308 |     start_batch: int,
@@ -289,20 +338,10 @@ async def do_sync_training_with_stream_minibatch( |
289 | 338 |         # Run evaluations
290 | 339 |         if (cfg.eval_every > 0 and i_batch % cfg.eval_every == 0) or i_batch == end_batch - 1:
291 | 340 |             with timed("run_evals", metrics):
292 | | -                for evaluator in evaluators:
293 | | -                    ev_name = (
294 | | -                        evaluator.name
295 | | -                        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
296 | | -                        else ""
297 | | -                    )
298 | | -                    with _get_logtree_scope(
299 | | -                        log_path=cfg.log_path,
300 | | -                        num_groups_to_log=cfg.num_groups_to_log,
301 | | -                        f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
302 | | -                        scope_name=f"Running evaluation {ev_name} {i_batch}",
303 | | -                    ):
304 | | -                        eval_metrics = await evaluator(sampling_client)
305 | | -                        metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
| 341 | +                eval_metrics = await run_evaluations_parallel(
| 342 | +                    evaluators, sampling_client, cfg, i_batch
| 343 | +                )
| 344 | +                metrics.update(eval_metrics)
306 | 345 |
307 | 346 |         with _get_logtree_scope(
308 | 347 |             cfg.log_path,
@@ -924,20 +963,10 @@ async def do_sync_training( |
924 | 963 |         # Run evaluations
925 | 964 |         if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0:
926 | 965 |             with timed("run_evals", metrics):
927 | | -                for evaluator in evaluators:
928 | | -                    ev_name = (
929 | | -                        evaluator.name
930 | | -                        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
931 | | -                        else ""
932 | | -                    )
933 | | -                    with _get_logtree_scope(
934 | | -                        log_path=cfg.log_path,
935 | | -                        num_groups_to_log=cfg.num_groups_to_log,
936 | | -                        f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
937 | | -                        scope_name=f"Running evaluation {ev_name} {i_batch}",
938 | | -                    ):
939 | | -                        eval_metrics = await evaluator(sampling_client)
940 | | -                        metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
| 966 | +                eval_metrics = await run_evaluations_parallel(
| 967 | +                    evaluators, sampling_client, cfg, i_batch
| 968 | +                )
| 969 | +                metrics.update(eval_metrics)
941 | 970 |
942 | 971 |         # Get batch and sample trajectories
943 | 972 |         env_group_builders_P = dataset.get_batch(i_batch)
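For reference, a minimal self-contained sketch of the pattern this change introduces: one named `asyncio` task per evaluator, gathered concurrently, with the per-evaluator metric dicts merged under a `test/` prefix. `DummyEvaluator`, `run_evals_concurrently`, and the metric names here are illustrative stand-ins, not code from this diff, and the sketch omits the logtree scoping and config plumbing of the real function.

```python
import asyncio
from typing import Any


class DummyEvaluator:
    """Illustrative stand-in for a SamplingClientEvaluator (hypothetical, not from the diff)."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def __call__(self, sampling_client: Any) -> dict[str, float]:
        await asyncio.sleep(0.1)  # simulate sampling/grading latency
        return {f"{self.name}/accuracy": 0.5}  # placeholder metric


async def run_evals_concurrently(
    evaluators: list[DummyEvaluator], sampling_client: Any
) -> dict[str, float]:
    # One named task per evaluator so slow or failing evals are traceable by task name.
    tasks = [
        asyncio.create_task(evaluator(sampling_client), name=f"eval_{evaluator.name or i}")
        for i, evaluator in enumerate(evaluators)
    ]
    # Evaluators run concurrently instead of one after another.
    results = await asyncio.gather(*tasks)
    # Merge all per-evaluator metrics into a single dict with a "test/" prefix.
    metrics: dict[str, float] = {}
    for result in results:
        metrics.update({f"test/{k}": v for k, v in result.items()})
    return metrics


if __name__ == "__main__":
    evaluators = [DummyEvaluator("math"), DummyEvaluator("code")]
    print(asyncio.run(run_evals_concurrently(evaluators, sampling_client=None)))
```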