Skip to content

Commit 406e6fe

Browse files
authored
[Tracing the single evaluation name] (#67)
1 parent d19c3d5 commit 406e6fe

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

tinker_cookbook/rl/train.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,19 @@ class Config:
262262
num_groups_to_log: int = 4 # Number of groups to log per iteration (0 = disable logging)
263263

264264

265+
@scope
266+
async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client):
267+
ev_name = _get_evaluator_name(evaluator)
268+
with _get_logtree_scope(
269+
log_path=cfg.log_path,
270+
num_groups_to_log=cfg.num_groups_to_log,
271+
f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
272+
scope_name=f"Running evaluation {ev_name} {i_batch}",
273+
):
274+
eval_metrics = await evaluator(sampling_client)
275+
return {f"test/{k}": v for k, v in eval_metrics.items()}
276+
277+
265278
@scope
266279
async def run_evaluations_parallel(
267280
evaluators: list[SamplingClientEvaluator],
@@ -271,17 +284,6 @@ async def run_evaluations_parallel(
271284
) -> dict[str, Any]:
272285
"""Run all evaluators in parallel and return aggregated metrics."""
273286

274-
async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client):
275-
ev_name = _get_evaluator_name(evaluator)
276-
with _get_logtree_scope(
277-
log_path=cfg.log_path,
278-
num_groups_to_log=cfg.num_groups_to_log,
279-
f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
280-
scope_name=f"Running evaluation {ev_name} {i_batch}",
281-
):
282-
eval_metrics = await evaluator(sampling_client)
283-
return {f"test/{k}": v for k, v in eval_metrics.items()}
284-
285287
# Create tasks for all evaluators with names for better traceability
286288
tasks = []
287289
for i, evaluator in enumerate(evaluators):

0 commit comments

Comments
 (0)