Skip to content

Commit c0179eb

Browse files
authored
parallelize evaluation in rl-training (#60)
1 parent 42ac624 commit c0179eb

File tree

1 file changed

+57
-28
lines changed

1 file changed

+57
-28
lines changed

tinker_cookbook/rl/train.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@
4646
logger = logging.getLogger(__name__)
4747

4848

49+
def _get_evaluator_name(evaluator: SamplingClientEvaluator) -> str:
    """Best-effort human-readable name for an evaluator.

    Only ``RLTestSetEvaluator`` instances carry a ``name`` attribute; for any
    other evaluator type, or when the name is unset, fall back to "".
    """
    if isinstance(evaluator, RLTestSetEvaluator):
        name = evaluator.name
        if name is not None:
            return name
    return ""
55+
56+
4957
@contextmanager
5058
def _get_logtree_scope(
5159
log_path: str | None, num_groups_to_log: int, f_name: str, scope_name: str
@@ -254,6 +262,47 @@ class Config:
254262
num_groups_to_log: int = 4 # Number of groups to log per iteration (0 = disable logging)
255263

256264

265+
@scope
async def run_evaluations_parallel(
    evaluators: list[SamplingClientEvaluator],
    sampling_client: tinker.SamplingClient,
    cfg: Config,
    i_batch: int,
) -> dict[str, Any]:
    """Run all evaluators concurrently and return their aggregated metrics.

    Each evaluator runs inside its own logtree scope so its rollouts are
    logged to a separate, identifiable section. Results are merged into a
    single dict with every key prefixed by ``test/``.

    Args:
        evaluators: Evaluators to run against the current sampling client.
        sampling_client: Client snapshotted at the current training step.
        cfg: Training config; supplies ``log_path`` / ``num_groups_to_log``.
        i_batch: Current training iteration, used to label log files.

    Returns:
        Merged ``{"test/<metric>": value}`` dict across all evaluators.
        NOTE: evaluators emitting identical metric keys overwrite each other;
        keys are assumed to be unique per evaluator.

    Raises:
        Whatever the first failing evaluator raises; remaining in-flight
        evaluations are cancelled before the exception propagates.
    """

    async def run_single_evaluation(evaluator: SamplingClientEvaluator) -> dict[str, Any]:
        # Closes over cfg / i_batch / sampling_client from the outer scope.
        ev_name = _get_evaluator_name(evaluator)
        with _get_logtree_scope(
            log_path=cfg.log_path,
            num_groups_to_log=cfg.num_groups_to_log,
            f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
            scope_name=f"Running evaluation {ev_name} {i_batch}",
        ):
            eval_metrics = await evaluator(sampling_client)
            return {f"test/{k}": v for k, v in eval_metrics.items()}

    # Name each task (falling back to its index when the evaluator is
    # anonymous) so failures are traceable in asyncio debug output.
    tasks = [
        asyncio.create_task(
            run_single_evaluation(evaluator),
            name=f"eval_{_get_evaluator_name(evaluator) or i}_iteration_{i_batch:06d}",
        )
        for i, evaluator in enumerate(evaluators)
    ]

    try:
        results = await asyncio.gather(*tasks)
    except BaseException:
        # gather() propagates the first failure but leaves sibling tasks
        # running in the background; cancel them so we don't leak work
        # (cancelling an already-finished task is a no-op).
        for task in tasks:
            task.cancel()
        raise

    # Merge all per-evaluator metric dicts into one.
    metrics: dict[str, Any] = {}
    for result in results:
        metrics.update(result)
    return metrics
304+
305+
257306
@scope
258307
async def do_sync_training_with_stream_minibatch(
259308
start_batch: int,
@@ -289,20 +338,10 @@ async def do_sync_training_with_stream_minibatch(
289338
# Run evaluations
290339
if (cfg.eval_every > 0 and i_batch % cfg.eval_every == 0) or i_batch == end_batch - 1:
291340
with timed("run_evals", metrics):
292-
for evaluator in evaluators:
293-
ev_name = (
294-
evaluator.name
295-
if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
296-
else ""
297-
)
298-
with _get_logtree_scope(
299-
log_path=cfg.log_path,
300-
num_groups_to_log=cfg.num_groups_to_log,
301-
f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
302-
scope_name=f"Running evaluation {ev_name} {i_batch}",
303-
):
304-
eval_metrics = await evaluator(sampling_client)
305-
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
341+
eval_metrics = await run_evaluations_parallel(
342+
evaluators, sampling_client, cfg, i_batch
343+
)
344+
metrics.update(eval_metrics)
306345

307346
with _get_logtree_scope(
308347
cfg.log_path,
@@ -924,20 +963,10 @@ async def do_sync_training(
924963
# Run evaluations
925964
if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0:
926965
with timed("run_evals", metrics):
927-
for evaluator in evaluators:
928-
ev_name = (
929-
evaluator.name
930-
if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
931-
else ""
932-
)
933-
with _get_logtree_scope(
934-
log_path=cfg.log_path,
935-
num_groups_to_log=cfg.num_groups_to_log,
936-
f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
937-
scope_name=f"Running evaluation {ev_name} {i_batch}",
938-
):
939-
eval_metrics = await evaluator(sampling_client)
940-
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
966+
eval_metrics = await run_evaluations_parallel(
967+
evaluators, sampling_client, cfg, i_batch
968+
)
969+
metrics.update(eval_metrics)
941970

942971
# Get batch and sample trajectories
943972
env_group_builders_P = dataset.get_batch(i_batch)

0 commit comments

Comments
 (0)