Skip to content

Commit 0776768

Browse files
authored
Adding more comments to logtree in eval (#56)
1 parent 54c42ea commit 0776768

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

tinker_cookbook/rl/metric_util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def __init__(
109109
max_tokens: int,
110110
name: str | None = None,
111111
num_groups_to_log: int = 4,
112-
log_path: str | None = None,
113112
):
114113
self.env_group_builders_P = dataset_to_env_group_builders(dataset)
115114
self.max_tokens = max_tokens

tinker_cookbook/rl/train.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,14 @@
4747

4848

4949
@contextmanager
50-
def get_logtree_scope(
50+
def _get_logtree_scope(
5151
log_path: str | None, num_groups_to_log: int, f_name: str, scope_name: str
5252
) -> Iterator[None]:
53+
"""
54+
Creates a context manager; all log inside this context will be logged under the section `scope_name`.
55+
It will create a file with the path of log_path/f_name.html
56+
If num_groups_to_log is 0, it will disable logging (but note that this function does not actually implement the logic for logging itself!)
57+
"""
5358
if log_path is not None and num_groups_to_log > 0:
5459
logtree_path = os.path.join(log_path, f"{f_name}.html")
5560
with logtree.init_trace(scope_name, path=logtree_path):
@@ -290,7 +295,7 @@ async def do_sync_training_with_stream_minibatch(
290295
if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
291296
else ""
292297
)
293-
with get_logtree_scope(
298+
with _get_logtree_scope(
294299
log_path=cfg.log_path,
295300
num_groups_to_log=cfg.num_groups_to_log,
296301
f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
@@ -299,7 +304,7 @@ async def do_sync_training_with_stream_minibatch(
299304
eval_metrics = await evaluator(sampling_client)
300305
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
301306

302-
with get_logtree_scope(
307+
with _get_logtree_scope(
303308
cfg.log_path,
304309
cfg.num_groups_to_log,
305310
f"train_iteration_{i_batch:06d}",
@@ -925,7 +930,7 @@ async def do_sync_training(
925930
if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
926931
else ""
927932
)
928-
with get_logtree_scope(
933+
with _get_logtree_scope(
929934
log_path=cfg.log_path,
930935
num_groups_to_log=cfg.num_groups_to_log,
931936
f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
@@ -938,7 +943,7 @@ async def do_sync_training(
938943
env_group_builders_P = dataset.get_batch(i_batch)
939944

940945
# Initialize logtree trace for this iteration if logging is enabled
941-
with get_logtree_scope(
946+
with _get_logtree_scope(
942947
log_path=cfg.log_path,
943948
num_groups_to_log=cfg.num_groups_to_log,
944949
f_name=f"train_iteration_{i_batch:06d}",

0 commit comments

Comments
 (0)