Skip to content

Commit d8b7dcf

Browse files
Merge pull request #2692 from AI-Hypercomputer:xfgu-metrics
PiperOrigin-RevId: 832449576
2 parents 3c0fe16 + ec64e49 commit d8b7dcf

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/MaxText/configs/rl.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ log_period: 20
7676
# ====== Debugging ======
7777
debug:
7878
rl: True
79+
# If True, enables Tunix-managed performance metrics collection. The metrics will be
80+
# uploaded to TensorBoard.
81+
enable_tunix_perf_metrics: False
7982

8083
# ====== Training ======
8184
batch_size: 1

src/MaxText/rl/train_rl.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def rl_train(tmvp_config):
285285
# Setup metrics logging
286286
max_logging.log(f"Tensorboard logs directory: {tmvp_config.tensorboard_dir}")
287287
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
288-
log_dir=tmvp_config.tensorboard_dir, flush_every_n_steps=tmvp_config.log_period
288+
log_dir=tmvp_config.tensorboard_dir, flush_every_n_steps=tmvp_config.log_period
289289
)
290290

291291
profiler_options = None
@@ -335,7 +335,7 @@ def rl_train(tmvp_config):
335335
rollout_vllm_hbm_utilization=tmvp_config.hbm_utilization_vllm,
336336
rollout_vllm_tpu_backend_type="jax",
337337
rollout_vllm_swap_space_size_gb=tmvp_config.swap_space_vllm_gb,
338-
),
338+
),
339339
)
340340
grpo_config = GrpoConfig(
341341
num_generations=tmvp_config.num_generations,
@@ -347,12 +347,29 @@ def rl_train(tmvp_config):
347347

348348
# Create RL cluster
349349
max_logging.log("Creating RL cluster...")
350+
rl_cluster_kwargs = {}
351+
if tmvp_config.enable_tunix_perf_metrics:
352+
try:
353+
from tunix.perf import export as perf_export # pylint: disable=import-outside-toplevel
354+
from tunix.perf import metrics as perf_metrics # pylint: disable=import-outside-toplevel
355+
356+
max_logging.log(
357+
"enable_tunix_perf_metrics is True and tunix.perf modules are available, enabling Tunix-managed metrics."
358+
)
359+
perf_config = perf_metrics.PerfMetricsConfig()
360+
perf_config.custom_export_fn = perf_export.PerfMetricsExport.create_metrics_export_fn(cluster_config)
361+
rl_cluster_kwargs["perf_config"] = perf_config
362+
except ImportError:
363+
max_logging.log(
364+
"enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
365+
)
350366
with nn_partitioning.axis_rules(tmvp_config.logical_axis_rules):
351367
rl_cluster = rl_cluster_lib.RLCluster(
352368
actor=actor_model,
353369
reference=reference_model,
354370
tokenizer=model_tokenizer,
355371
cluster_config=cluster_config,
372+
**rl_cluster_kwargs,
356373
)
357374

358375
# Create GRPO trainer

0 commit comments

Comments
 (0)