Skip to content

Commit 54c42ea

Browse files
authored
[logtree for RLTestEvaluator] (#54)
1 parent 51d9e82 commit 54c42ea

File tree

4 files changed

+95
-90
lines changed

4 files changed

+95
-90
lines changed

tinker_cookbook/rl/metric_util.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tinker_cookbook.rl.rollouts import do_group_rollout
1111
from tinker_cookbook.rl.types import EnvGroupBuilder, RLDataset, TrajectoryGroup
1212
from tinker_cookbook.utils.misc_utils import all_same, dict_mean
13+
from tinker_cookbook.utils import logtree
1314

1415

1516
def _compute_by_group_metrics(trajectory_groups_P: List[TrajectoryGroup], good_thresh: float = 0.5):
@@ -102,15 +103,29 @@ def dataset_to_env_group_builders(dataset: RLDataset) -> list[EnvGroupBuilder]:
102103

103104

104105
class RLTestSetEvaluator(SamplingClientEvaluator):
105-
def __init__(self, dataset: RLDataset, max_tokens: int, name: str | None = None):
106+
def __init__(
107+
self,
108+
dataset: RLDataset,
109+
max_tokens: int,
110+
name: str | None = None,
111+
num_groups_to_log: int = 4,
112+
log_path: str | None = None,
113+
):
106114
self.env_group_builders_P = dataset_to_env_group_builders(dataset)
107115
self.max_tokens = max_tokens
108116
self.name = name
117+
self.num_groups_to_log = num_groups_to_log
109118

110119
async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
111120
policy = TinkerTokenCompleter(sampling_client, max_tokens=self.max_tokens)
121+
122+
async def run_group_rollout(builder, i):
123+
enable_logging = i < self.num_groups_to_log
124+
with logtree.optional_enable_logging(enable=enable_logging):
125+
return await do_group_rollout(builder, policy)
126+
112127
trajectory_groups_P = await asyncio.gather(
113-
*[do_group_rollout(builder, policy) for builder in self.env_group_builders_P]
128+
*[run_group_rollout(builder, i) for i, builder in enumerate(self.env_group_builders_P)]
114129
)
115130
taglist_P = [builder.logging_tags() for builder in self.env_group_builders_P]
116131
metrics = compute_trajectory_metrics(trajectory_groups_P, taglist_P)

tinker_cookbook/rl/train.py

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
import logging
88
import os
99
import time
10-
from contextlib import nullcontext
11-
from typing import Any, Callable, List, Literal, Sequence
10+
from typing import Any, Callable, List, Literal, Sequence, Iterator
1211

1312
import chz
1413
import numpy as np
@@ -41,10 +40,24 @@
4140
from tinker_cookbook.utils import logtree, ml_log
4241
from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
4342
from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
43+
from contextlib import contextmanager
44+
4445

4546
logger = logging.getLogger(__name__)
4647

4748

49+
@contextmanager
50+
def get_logtree_scope(
51+
log_path: str | None, num_groups_to_log: int, f_name: str, scope_name: str
52+
) -> Iterator[None]:
53+
if log_path is not None and num_groups_to_log > 0:
54+
logtree_path = os.path.join(log_path, f"{f_name}.html")
55+
with logtree.init_trace(scope_name, path=logtree_path):
56+
yield
57+
else:
58+
yield
59+
60+
4861
@scope
4962
def _select_representative_inds(scores: list[float], num_inds: int) -> list[int]:
5063
assert num_inds <= len(scores)
@@ -272,19 +285,25 @@ async def do_sync_training_with_stream_minibatch(
272285
if (cfg.eval_every > 0 and i_batch % cfg.eval_every == 0) or i_batch == end_batch - 1:
273286
with timed("run_evals", metrics):
274287
for evaluator in evaluators:
275-
eval_metrics = await evaluator(sampling_client)
276-
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
288+
ev_name = (
289+
evaluator.name
290+
if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
291+
else ""
292+
)
293+
with get_logtree_scope(
294+
log_path=cfg.log_path,
295+
num_groups_to_log=cfg.num_groups_to_log,
296+
f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
297+
scope_name=f"Running evaluation {ev_name} {i_batch}",
298+
):
299+
eval_metrics = await evaluator(sampling_client)
300+
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
277301

278-
# Initialize logtree trace for this iteration if logging is enabled
279-
logtree_path = (
280-
os.path.join(cfg.log_path, f"iteration_{i_batch:06d}.html")
281-
if cfg.num_groups_to_log > 0
282-
else None
283-
)
284-
with (
285-
logtree.init_trace(f"RL Iteration {i_batch}", path=logtree_path)
286-
if logtree_path
287-
else logtree.scope_disable()
302+
with get_logtree_scope(
303+
cfg.log_path,
304+
cfg.num_groups_to_log,
305+
f"train_iteration_{i_batch:06d}",
306+
f"RL Iteration {i_batch}",
288307
):
289308
# Samplers will produce trajectory groups asynchronously,
290309
# and the trainer will consume them as soon as they are ready
@@ -598,7 +617,8 @@ async def do_group_rollout_and_filter_constant_reward(
598617
enable_logging: bool = True,
599618
) -> TrajectoryGroup | None:
600619
policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens)
601-
with nullcontext() if enable_logging else logtree.scope_disable():
620+
621+
with logtree.optional_enable_logging(enable_logging):
602622
trajectory_group = await do_group_rollout(env_group_builder, policy)
603623

604624
# Remove if all trajectories have the same reward
@@ -900,23 +920,29 @@ async def do_sync_training(
900920
if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0:
901921
with timed("run_evals", metrics):
902922
for evaluator in evaluators:
903-
eval_metrics = await evaluator(sampling_client)
904-
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
923+
ev_name = (
924+
evaluator.name
925+
if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
926+
else ""
927+
)
928+
with get_logtree_scope(
929+
log_path=cfg.log_path,
930+
num_groups_to_log=cfg.num_groups_to_log,
931+
f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
932+
scope_name=f"Running evaluation {ev_name} {i_batch}",
933+
):
934+
eval_metrics = await evaluator(sampling_client)
935+
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
905936

906937
# Get batch and sample trajectories
907938
env_group_builders_P = dataset.get_batch(i_batch)
908939

909940
# Initialize logtree trace for this iteration if logging is enabled
910-
logtree_path = (
911-
os.path.join(cfg.log_path, f"iteration_{i_batch:06d}.html")
912-
if cfg.num_groups_to_log > 0
913-
else None
914-
)
915-
with (
916-
logtree.init_trace(f"RL Iteration {i_batch}", path=logtree_path)
917-
if logtree_path
918-
else logtree.scope_disable(),
919-
timed("sample", metrics),
941+
with get_logtree_scope(
942+
log_path=cfg.log_path,
943+
num_groups_to_log=cfg.num_groups_to_log,
944+
f_name=f"train_iteration_{i_batch:06d}",
945+
scope_name=f"RL Iteration {i_batch}",
920946
):
921947
trajectory_groups_P = await asyncio.gather(
922948
*[

tinker_cookbook/tests/test_logtree.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,26 +175,19 @@ def simple_function():
175175
def custom_title_function():
176176
logtree.log_text("Inside custom title function")
177177

178-
@logtree.scope_header_decorator(lambda x: f"Processing {x}")
179-
def dynamic_title_function(x):
180-
logtree.log_text(f"Value: {x}")
181-
182178
with tempfile.TemporaryDirectory() as tmpdir:
183179
output_path = Path(tmpdir) / "decorator.html"
184180

185181
with logtree.init_trace("Decorator Test", path=output_path):
186182
simple_function()
187183
custom_title_function()
188-
dynamic_title_function(42)
189184

190185
content = output_path.read_text()
191186

192187
assert "simple_function" in content
193188
assert "Inside simple function" in content
194189
assert "Custom Title" in content
195190
assert "Inside custom title function" in content
196-
assert "Processing 42" in content
197-
assert "Value: 42" in content
198191

199192

200193
async def async_test_scope_header_decorator():

tinker_cookbook/utils/logtree.py

Lines changed: 25 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,6 @@ def init_trace(
435435
_current_trace.reset(tok_t)
436436

437437

438-
# Public API: Structure
439-
440-
441438
@contextmanager
442439
def scope_header(title: str, **attrs: Any) -> Iterator[None]:
443440
"""
@@ -480,51 +477,37 @@ def scope_header(title: str, **attrs: Any) -> Iterator[None]:
480477
F = TypeVar("F", bound=Callable[..., Any])
481478

482479

483-
# Overloads to support both bare decorator and parameterized usage
484-
# More specific overloads must come first
480+
# Overloads the parameterized usage
485481
@overload
486-
def scope_header_decorator(title: str, **attrs: Any) -> Callable[[F], F]: ... # String title
487-
488-
489-
@overload
490-
def scope_header_decorator(
491-
title: Callable[..., str], **attrs: Any
492-
) -> Callable[[F], F]: ... # Lambda title
493-
494-
495-
@overload
496-
def scope_header_decorator(title: None = None, **attrs: Any) -> Callable[[F], F]: ... # No args
482+
def scope_header_decorator(title: str) -> Callable[[F], F]: ... # String title
497483

498484

485+
# Overloads the bare usage
499486
@overload
500487
def scope_header_decorator(title: F) -> F: ... # Bare: @scope_header_decorator
501488

502489

503490
def scope_header_decorator(
504-
title: str | Callable[..., str] | F | None = None, **attrs: Any
491+
title: str | F,
505492
) -> F | Callable[[F], F]:
506493
"""
507494
Decorator to wrap function in a scope_header.
508495
509496
Args:
510-
title: String, callable returning string, or None (use function name)
511-
**attrs: HTML attributes
497+
title: String or function returning string
512498
513499
Examples:
514500
@logtree.scope_header_decorator
515501
async def process_batch():
516502
...
517503
518-
@logtree.scope_header_decorator("Processing")
519-
def work():
520-
...
521-
522-
@logtree.scope_header_decorator(lambda self, x: f"Item {x}")
523-
def handle_item(self, x):
504+
@logtree.scope_header_decorator("Handling item")
505+
def handle_item():
524506
...
525507
"""
508+
title_str = title if isinstance(title, str) else title.__name__
526509

527-
def _wrap(fn: F, title_fn: Callable[..., str]) -> F:
510+
def _wrap(fn: F) -> F:
528511
if inspect.iscoroutinefunction(fn):
529512

530513
@functools.wraps(fn)
@@ -533,7 +516,7 @@ async def aw(*args: Any, **kwargs: Any) -> Any:
533516
if not _is_logging_enabled():
534517
return await fn(*args, **kwargs)
535518

536-
with scope_header(title_fn(*args, **kwargs), **attrs):
519+
with scope_header(title_str):
537520
return await fn(*args, **kwargs)
538521

539522
return aw # type: ignore
@@ -545,38 +528,16 @@ def w(*args: Any, **kwargs: Any) -> Any:
545528
if not _is_logging_enabled():
546529
return fn(*args, **kwargs)
547530

548-
with scope_header(title_fn(*args, **kwargs), **attrs):
531+
with scope_header(title_str):
549532
return fn(*args, **kwargs)
550533

551534
return w # type: ignore
552535

553-
# Check if this is a bare decorator (no arguments)
554-
# When used as @scope_header_decorator, title will be the decorated function.
555-
# When used as @scope_header_decorator("string") or @scope_header_decorator(lambda ...),
556-
# title will be a string or lambda, and this returns a decorator function.
557-
# We distinguish by checking if it's a function but NOT a lambda.
558-
if inspect.isfunction(title) and title.__name__ != "<lambda>" and not attrs:
559-
# Bare decoration: @scope_header_decorator
536+
if isinstance(title, str):
537+
return _wrap
538+
else:
560539
fn = title
561-
return _wrap(fn, lambda *_args, **_kwargs: fn.__name__) # type: ignore[arg-type]
562-
563-
# Parameterized decorator
564-
def deco(fn: F) -> F:
565-
if title is None:
566-
567-
def title_fn(*_args: Any, **_kwargs: Any) -> str:
568-
return fn.__name__
569-
570-
elif callable(title):
571-
title_fn = title
572-
else:
573-
574-
def title_fn(*_args: Any, **_kwargs: Any) -> str:
575-
return str(title)
576-
577-
return _wrap(fn, title_fn)
578-
579-
return deco
540+
return _wrap(fn)
580541

581542

582543
@contextmanager
@@ -618,6 +579,16 @@ def scope_disable() -> Iterator[None]:
618579
_logging_disabled.reset(token)
619580

620581

582+
@contextmanager
583+
def optional_enable_logging(enable: bool) -> Iterator[None]:
584+
"""Context manager to optionally enable logging."""
585+
if enable:
586+
yield
587+
else:
588+
with scope_disable():
589+
yield
590+
591+
621592
@contextmanager
622593
def scope_details(summary: str) -> Iterator[None]:
623594
"""

0 commit comments

Comments
 (0)