
Commit 1ba9aa5

temp remove profiler (de)activation logging
1 parent 188c426

File tree

1 file changed: +75 -35 lines changed


src/MaxText/metric_logger.py

Lines changed: 75 additions & 35 deletions
@@ -104,43 +104,83 @@ def write_metrics(self, metrics, step, is_training=True):
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
-      loss = metrics["scalar"]["learning/loss"]
-      # Do not show flops and tokens during batch size rampup
-      if step >= self.config.rampup_end_step:
-        log_message = (
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-            f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-      else:
-        log_message = (
-            "[Rampup Batch Size Phase]: "
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-
-      if self.config.mtp_num_layers > 0:
-        mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
-        main_model_loss = loss - mtp_loss
-        log_message += f", main_model_loss: {main_model_loss:.3f}, mtp_loss: {mtp_loss:.3f}"
-
+      self._log_training_metrics(metrics, step)
     else:
-      log_message = (
-          f"eval metrics after step: {step},"
-          f" loss={metrics['scalar']['eval/avg_loss']:.3f},"
-          f" total_weights={metrics['scalar']['eval/total_weights']}"
+      self._log_eval_metrics(metrics, step)
+
+  def _log_training_metrics(self, metrics, step):
+    """Handles training-specific metric logging."""
+    # Log a reduced message (no step timing or throughput) on profiler activation/deactivation steps.
+    # TODO(b/456828037): Switch to subprocess profiling to avoid timing artifacts at boundary steps.
+    scalars = metrics["scalar"]
+    loss = scalars["learning/loss"]
+    is_rampup = step < self.config.rampup_end_step
+    is_profiler_boundary = self._is_profiler_boundary_step(step)
+
+    # Start building the log parts
+    log_parts = []
+    if is_rampup:
+      log_parts.append("[Rampup Batch Size Phase]")
+
+    if is_profiler_boundary:
+      log_parts.append(
+          f"completed profiler activation/deactivation step: {step}",
       )
-
-      if self.config.mtp_num_layers > 0:
-        log_message += (
-            f", avg_mtp_loss={metrics['scalar']['eval/avg_mtp_loss']:.3f},"
-            f" avg_mtp_acceptance_rate={metrics['scalar']['eval/avg_mtp_acceptance_rate_percent']:.2f}%"
-        )
-
-    max_logging.log(log_message)
+    else:
+      log_parts.extend([
+          f"completed step: {step}",
+          f"seconds: {scalars['perf/step_time_seconds']:.3f}",
+      ])
+
+    # Add performance metrics only outside the rampup phase and profiler boundary steps.
+    # TODO(b/452468482): Enable performance metric (TFLOPs, Tokens/s) tracking during batch size rampup.
+    if not is_rampup and not is_profiler_boundary:
+      log_parts.extend([
+          f"TFLOP/s/device: {scalars['perf/per_device_tflops_per_sec']:.3f}",
+          f"Tokens/s/device: {scalars['perf/per_device_tokens_per_sec']:.3f}",
+      ])
+
+    log_parts.extend([
+        f"total_weights: {scalars['learning/total_weights']}",
+        f"loss: {loss:.3f}",
+    ])
+
+    if self.config.mtp_num_layers > 0:
+      mtp_loss = scalars.get("learning/mtp_loss", 0.0)
+      log_parts.append(f"main_model_loss: {loss - mtp_loss:.3f}")
+      log_parts.append(f"mtp_loss: {mtp_loss:.3f}")
+
+    max_logging.log(", ".join(log_parts))
+
+  def _log_eval_metrics(self, metrics, step):
+    """Handles evaluation-specific metric logging."""
+    scalars = metrics["scalar"]
+    log_parts = [
+        f"eval metrics after step: {step}",
+        f"loss={scalars['eval/avg_loss']:.3f}",
+        f"total_weights={scalars['eval/total_weights']}",
+    ]
+
+    if self.config.mtp_num_layers > 0:
+      log_parts.extend([
+          f"avg_mtp_loss={scalars['eval/avg_mtp_loss']:.3f}",
+          f"avg_mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
+      ])
+
+    max_logging.log(", ".join(log_parts))
+
+  def _is_profiler_boundary_step(self, step):
+    """Determines if the current step is a profiler start/stop boundary that should be hidden."""
+    skip_steps = self.config.skip_first_n_steps_for_profiler
+    profiler_steps = self.config.profiler_steps
+    # Steps immediately before/at start, and at/immediately after end of profiling
+    boundary_steps = {
+        skip_steps,
+        skip_steps + 1,
+        skip_steps + profiler_steps,
+        skip_steps + profiler_steps + 1,
+    }
+    return step in boundary_steps
 
   def write_metrics_locally(self, metrics, step):
     """Writes metrics locally for testing."""

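For reference, this sketch mimics how _log_training_metrics assembles its message on a profiler boundary step via ", ".join(log_parts); the numeric values are made up:

# Illustrative only: the shape of the boundary-step log line, with made-up values.
step, total_weights, loss = 11, 1048576, 2.345
log_parts = [
    f"completed profiler activation/deactivation step: {step}",
    f"total_weights: {total_weights}",
    f"loss: {loss:.3f}",
]
print(", ".join(log_parts))
# completed profiler activation/deactivation step: 11, total_weights: 1048576, loss: 2.345

An ordinary step would instead begin with "completed step: ..., seconds: ..." and, outside the rampup phase, also include the TFLOP/s/device and Tokens/s/device entries.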