Commit 7f8ed85

temp remove profiler (de)activation logging
1 parent 188c426 commit 7f8ed85

File tree

2 files changed: +84, -33 lines

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -624,6 +624,7 @@ upload_all_profiler_results: False
 skip_first_n_steps_for_profiler: 1
 # Profile for a small number of steps to avoid a large profile file size.
 profiler_steps: 5
+hide_profiler_step_metric: False
 profile_cleanly: True # If set to true, adds a block_until_ready on train state which aligns the profile for each step.
 profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
 # This is useful to debug scenarios where performance is changing.
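
The new hide_profiler_step_metric flag is off by default. When it is set to True and a profiler is configured, the logger changes in src/MaxText/metric_logger.py below suppress step-time and throughput metrics on the steps where the profiler starts or stops. As a rough illustration (not part of the commit), the affected steps under the default values shown above (skip_first_n_steps_for_profiler: 1, profiler_steps: 5) work out as follows:

# Illustrative sketch only: evaluates the boundary-step formula from the
# metric_logger.py change below, using the default values from base.yml.
skip_first_n_steps_for_profiler = 1
profiler_steps = 5

boundary_steps = {
    skip_first_n_steps_for_profiler,
    skip_first_n_steps_for_profiler + 1,
    skip_first_n_steps_for_profiler + profiler_steps,
    skip_first_n_steps_for_profiler + profiler_steps + 1,
}
print(sorted(boundary_steps))  # [1, 2, 6, 7]

So with the defaults, enabling the flag would replace the usual metric line on steps 1, 2, 6, and 7 with the shorter "completed profiler activation/deactivation step" message.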

src/MaxText/metric_logger.py

Lines changed: 83 additions & 33 deletions
@@ -104,43 +104,93 @@ def write_metrics(self, metrics, step, is_training=True):
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
-      loss = metrics["scalar"]["learning/loss"]
-      # Do not show flops and tokens during batch size rampup
-      if step >= self.config.rampup_end_step:
-        log_message = (
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-            f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-      else:
-        log_message = (
-            "[Rampup Batch Size Phase]: "
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-
-      if self.config.mtp_num_layers > 0:
-        mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
-        main_model_loss = loss - mtp_loss
-        log_message += f", main_model_loss: {main_model_loss:.3f}, mtp_loss: {mtp_loss:.3f}"
-
+      self._log_training_metrics(metrics, step)
+    else:
+      self._log_eval_metrics(metrics, step)
+
+  def _log_training_metrics(self, metrics, step):
+    """Handles training-specific metric logging."""
+    # Skip logging if in profiler activation/deactivation steps
+    # TODO(b/456828037): Switch to subprocess profiling to avoid timing artifacts at boundary steps.
+    scalars = metrics["scalar"]
+    loss = scalars["learning/loss"]
+    is_rampup = step < self.config.rampup_end_step
+    is_metric_hidden_step = self.config.hide_profiler_step_metric & self._is_profiler_boundary_step(step)
+
+    # Start building the log parts
+    log_parts = []
+    if is_rampup:
+      log_parts.append("[Rampup Batch Size Phase]")
+
+    if is_metric_hidden_step:
+      log_parts.append(
+          f"completed profiler activation/deactivation step: {step}",
+      )
     else:
-      log_message = (
-          f"eval metrics after step: {step},"
-          f" loss={metrics['scalar']['eval/avg_loss']:.3f},"
-          f" total_weights={metrics['scalar']['eval/total_weights']}"
+      log_parts.extend(
+          [
+              f"completed step: {step}",
+              f"seconds: {scalars['perf/step_time_seconds']:.3f}",
+          ]
       )

-      if self.config.mtp_num_layers > 0:
-        log_message += (
-            f", avg_mtp_loss={metrics['scalar']['eval/avg_mtp_loss']:.3f},"
-            f" avg_mtp_acceptance_rate={metrics['scalar']['eval/avg_mtp_acceptance_rate_percent']:.2f}%"
-        )
+    # Add performance metrics only if strictly NOT in rampup phase
+    # TODO(b/452468482): Enable performance metric (TFLOPs, Tokens/s) tracking during batch size rampup.
+    if not is_rampup and not is_metric_hidden_step:
+      log_parts.extend(
+          [
+              f"TFLOP/s/device: {scalars['perf/per_device_tflops_per_sec']:.3f}",
+              f"Tokens/s/device: {scalars['perf/per_device_tokens_per_sec']:.3f}",
+          ]
+      )
+
+    log_parts.extend(
+        [
+            f"total_weights: {scalars['learning/total_weights']}",
+            f"loss: {loss:.3f}",
+        ]
+    )
+
+    if self.config.mtp_num_layers > 0:
+      mtp_loss = scalars.get("learning/mtp_loss", 0.0)
+      log_parts.append(f"main_model_loss: {loss - mtp_loss:.3f}")
+      log_parts.append(f"mtp_loss: {mtp_loss:.3f}")
+
+    max_logging.log(", ".join(log_parts))
+
+  def _log_eval_metrics(self, metrics, step):
+    """Handles evaluation-specific metric logging."""
+    scalars = metrics["scalar"]
+    log_parts = [
+        f"eval metrics after step: {step}",
+        f"loss={scalars['eval/avg_loss']:.3f}",
+        f"total_weights={scalars['eval/total_weights']}",
+    ]
+
+    if self.config.mtp_num_layers > 0:
+      log_parts.extend(
+          [
+              f"avg_mtp_loss={scalars['eval/avg_mtp_loss']:.3f}",
+              f"avg_mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
+          ]
+      )

-    max_logging.log(log_message)
+    max_logging.log(", ".join(log_parts))
+
+  def _is_profiler_boundary_step(self, step):
+    """Determines if the current step is a profiler start/stop boundary that should be hidden."""
+    if len(self.config.profiler) == 0:
+      return False
+    skip_steps = self.config.skip_first_n_steps_for_profiler
+    profiler_steps = self.config.profiler_steps
+    # Steps immediately before/at start, and at/immediately after end of profiling
+    boundary_steps = {
+        skip_steps,
+        skip_steps + 1,
+        skip_steps + profiler_steps,
+        skip_steps + profiler_steps + 1,
+    }
+    return step in boundary_steps

   def write_metrics_locally(self, metrics, step):
     """Writes metrics locally for testing."""
