Skip to content

Commit 62b249a

Browse files
Merge pull request #2687 from AI-Hypercomputer:fix-goodput-v15-integration
PiperOrigin-RevId: 832608784
2 parents 81d653c + 4fc65f9 commit 62b249a

File tree

7 files changed, with 43 additions and 30 deletions.

src/MaxText/elastic_train.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -384,8 +384,7 @@ def main(argv: Sequence[str]) -> None:
384384
if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
385385
vertex_tensorboard_manager.configure_vertex_tensorboard(config)
386386

387-
# Goodput configurations
388-
maybe_monitor_goodput(config)
387+
# Create the Goodput recorder
389388
recorder = create_goodput_recorder(config)
390389

391390
# Stack traces configurations
@@ -399,7 +398,7 @@ def main(argv: Sequence[str]) -> None:
399398
diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
400399

401400
with diagnostic.diagnose(diagnostic_config):
402-
with maybe_record_goodput(recorder, GoodputEvent.JOB):
401+
with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
403402
train_loop(config, elastic_manager, recorder)
404403

405404

src/MaxText/experimental/rl/grpo_trainer.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -946,8 +946,7 @@ def main(argv: Sequence[str]) -> None:
946946
if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
947947
vertex_tensorboard_manager.configure_vertex_tensorboard(config)
948948

949-
# Goodput configurations
950-
maybe_monitor_goodput(config)
949+
# Create the Goodput recorder
951950
recorder = create_goodput_recorder(config)
952951

953952
# Stack traces configurations
@@ -961,7 +960,7 @@ def main(argv: Sequence[str]) -> None:
961960
diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
962961

963962
with diagnostic.diagnose(diagnostic_config):
964-
with maybe_record_goodput(recorder, GoodputEvent.JOB):
963+
with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
965964
train_loop(config, config_inference, recorder)
966965

967966

src/MaxText/sft/sft_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -194,10 +194,9 @@ def main(argv: Sequence[str]) -> None:
194194
mt_config = pyconfig.initialize(argv)
195195
max_utils.print_system_information()
196196

197-
maybe_monitor_goodput(mt_config)
198197
goodput_recorder = create_goodput_recorder(mt_config)
199198

200-
with maybe_record_goodput(goodput_recorder, GoodputEvent.JOB):
199+
with maybe_record_goodput(goodput_recorder, GoodputEvent.JOB), maybe_monitor_goodput(mt_config):
201200
train(mt_config, goodput_recorder)
202201

203202

src/MaxText/sft_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -168,9 +168,8 @@ def main(argv: Sequence[str]) -> None:
168168
validate_train_config(config)
169169
os.environ["TFDS_DATA_DIR"] = config.dataset_path
170170

171-
maybe_monitor_goodput(config)
172171
recorder = create_goodput_recorder(config)
173-
with maybe_record_goodput(recorder, GoodputEvent.JOB):
172+
with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
174173
train_loop(config, recorder)
175174

176175

src/MaxText/train.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -502,8 +502,7 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
502502
if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
503503
vertex_tensorboard_manager.configure_vertex_tensorboard(config)
504504

505-
# Goodput configurations
506-
maybe_monitor_goodput(config)
505+
# Create the Goodput recorder
507506
recorder = create_goodput_recorder(config)
508507

509508
# Stack traces configurations
@@ -524,6 +523,7 @@ def run(config, recorder, diagnostic_config):
524523
diagnostic.diagnose(diagnostic_config),
525524
maybe_record_goodput(recorder, GoodputEvent.JOB),
526525
max_utils.maybe_get_transformer_engine_context(config),
526+
maybe_monitor_goodput(config),
527527
):
528528
train_loop(config, recorder)
529529

src/MaxText/utils/goodput_utils.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,14 @@ class GoodputEvent(Enum):
3434
STEP = "step"
3535

3636

37+
@contextlib.contextmanager
3738
def maybe_monitor_goodput(config):
38-
"""Monitor goodput if `monitor_goodput=True`."""
39-
if config.monitor_goodput and jax.process_index() == 0:
40-
# Workload monitoring and Goodput monitoring both uses /workload/performance
41-
# GCM metric to publish step_time and step_deviation metrics. For now, we
42-
# will disable publishing step deviation metrics to GCM if workload
43-
# monitoring is enabled. Will reconcile this in the future.
39+
"""Monitor cumulative goodput if enabled."""
40+
if not config.monitor_goodput or jax.process_index() != 0:
41+
yield
42+
return
43+
goodput_monitor = None
44+
try:
4445
if config.report_performance_metric_for_gcp_monitoring:
4546
config.enable_gcp_step_deviation_metrics = False
4647

@@ -62,10 +63,11 @@ def maybe_monitor_goodput(config):
6263
)
6364
goodput_monitor.start_goodput_uploader()
6465
max_logging.log("Started Goodput upload to Tensorboard & GCM in the background!")
65-
66-
if config.monitor_step_time_deviation:
67-
goodput_monitor.start_step_deviation_uploader()
68-
max_logging.log("Started step time deviation upload to Tensorboard & GCM in the background!")
66+
yield
67+
finally:
68+
if goodput_monitor:
69+
goodput_monitor.stop_goodput_uploader()
70+
max_logging.log("Flushed final metrics and safe exited from Goodput monitoring.")
6971

7072

7173
@contextlib.contextmanager
@@ -75,9 +77,13 @@ def maybe_record_goodput(recorder, event_name, *args):
7577
start_event_name = f"record_{event_name.value}_start_time"
7678
record_goodput(recorder, start_event_name, *args)
7779
yield
78-
finally:
80+
except BaseException: # pylint: disable=W0706
81+
raise
82+
else:
7983
end_event_name = f"record_{event_name.value}_end_time"
8084
record_goodput(recorder, end_event_name, *args)
85+
finally:
86+
pass
8187

8288

8389
def record_goodput(recorder, event_name, *args):

tests/goodput_utils_test.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -53,16 +53,27 @@ def test_record_goodput(self, mock_cloud_logger, mock_record_job_start_time, moc
5353
mock_record_job_start_time.assert_called()
5454
mock_record_job_end_time.assert_called()
5555

56-
@mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor.start_step_deviation_uploader")
56+
class TestException(BaseException):
57+
pass
58+
59+
mock_record_job_start_time.reset_mock()
60+
mock_record_job_end_time.reset_mock()
61+
with self.assertRaises(TestException):
62+
with maybe_record_goodput(recorder, GoodputEvent.JOB):
63+
mock_record_job_start_time.assert_called_once()
64+
raise TestException()
65+
66+
mock_record_job_start_time.assert_called_once()
67+
mock_record_job_end_time.assert_not_called()
68+
69+
@mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor.stop_goodput_uploader")
5770
@mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor.start_goodput_uploader")
58-
def test_monitor_goodput(self, mock_start_goodput_uploader, mock_start_step_deviation_uploader):
71+
def test_monitor_goodput(self, mock_start_goodput_uploader, mock_stop_goodput_uploader):
5972
mock_start_goodput_uploader.return_value = mock.MagicMock()
60-
mock_start_step_deviation_uploader.return_value = mock.MagicMock()
61-
62-
maybe_monitor_goodput(self.config)
6373

64-
mock_start_goodput_uploader.assert_called()
65-
mock_start_step_deviation_uploader.assert_called()
74+
with maybe_monitor_goodput(self.config):
75+
mock_start_goodput_uploader.assert_called()
76+
mock_stop_goodput_uploader.assert_called()
6677

6778

6879
if __name__ == "__main__":

0 commit comments

Comments
 (0)