@@ -285,7 +285,7 @@ def rl_train(tmvp_config):
285285 # Setup metrics logging
286286 max_logging .log (f"Tensorboard logs directory: { tmvp_config .tensorboard_dir } " )
287287 metrics_logging_options = metrics_logger .MetricsLoggerOptions (
288- log_dir = tmvp_config .tensorboard_dir , flush_every_n_steps = tmvp_config .log_period
288+ log_dir = tmvp_config .tensorboard_dir , flush_every_n_steps = tmvp_config .log_period
289289 )
290290
291291 profiler_options = None
@@ -335,7 +335,7 @@ def rl_train(tmvp_config):
335335 rollout_vllm_hbm_utilization = tmvp_config .hbm_utilization_vllm ,
336336 rollout_vllm_tpu_backend_type = "jax" ,
337337 rollout_vllm_swap_space_size_gb = tmvp_config .swap_space_vllm_gb ,
338- ),
338+ ),
339339 )
340340 grpo_config = GrpoConfig (
341341 num_generations = tmvp_config .num_generations ,
@@ -347,12 +347,29 @@ def rl_train(tmvp_config):
347347
348348 # Create RL cluster
349349 max_logging .log ("Creating RL cluster..." )
350+ rl_cluster_kwargs = {}
351+ if tmvp_config .enable_tunix_perf_metrics :
352+ try :
353+ from tunix .perf import export as perf_export # pylint: disable=import-outside-toplevel
354+ from tunix .perf import metrics as perf_metrics # pylint: disable=import-outside-toplevel
355+
356+ max_logging .log (
357+ "enable_tunix_perf_metrics is True and tunix.perf modules are available, enabling Tunix-managed metrics."
358+ )
359+ perf_config = perf_metrics .PerfMetricsConfig ()
360+ perf_config .custom_export_fn = perf_export .PerfMetricsExport .create_metrics_export_fn (cluster_config )
361+ rl_cluster_kwargs ["perf_config" ] = perf_config
362+ except ImportError :
363+ max_logging .log (
364+ "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
365+ )
350366 with nn_partitioning .axis_rules (tmvp_config .logical_axis_rules ):
351367 rl_cluster = rl_cluster_lib .RLCluster (
352368 actor = actor_model ,
353369 reference = reference_model ,
354370 tokenizer = model_tokenizer ,
355371 cluster_config = cluster_config ,
372+ ** rl_cluster_kwargs ,
356373 )
357374
358375 # Create GRPO trainer
0 commit comments