@@ -34,13 +34,14 @@ class GoodputEvent(Enum):
3434 STEP = "step"
3535
3636
37+ @contextlib .contextmanager
3738def maybe_monitor_goodput (config ):
38- """Monitor goodput if `monitor_goodput=True` ."""
39- if config .monitor_goodput and jax .process_index () = = 0 :
40- # Workload monitoring and Goodput monitoring both uses /workload/performance
41- # GCM metric to publish step_time and step_deviation metrics. For now, we
42- # will disable publishing step deviation metrics to GCM if workload
43- # monitoring is enabled. Will reconcile this in the future.
39+ """Monitor cumulative goodput if enabled ."""
40+ if not config .monitor_goodput or jax .process_index () ! = 0 :
41+ yield
42+ return
43+ goodput_monitor = None
44+ try :
4445 if config .report_performance_metric_for_gcp_monitoring :
4546 config .enable_gcp_step_deviation_metrics = False
4647
@@ -62,10 +63,11 @@ def maybe_monitor_goodput(config):
6263 )
6364 goodput_monitor .start_goodput_uploader ()
6465 max_logging .log ("Started Goodput upload to Tensorboard & GCM in the background!" )
65-
66- if config .monitor_step_time_deviation :
67- goodput_monitor .start_step_deviation_uploader ()
68- max_logging .log ("Started step time deviation upload to Tensorboard & GCM in the background!" )
66+ yield
67+ finally :
68+ if goodput_monitor :
69+ goodput_monitor .stop_goodput_uploader ()
70+ max_logging .log ("Flushed final metrics and safe exited from Goodput monitoring." )
6971
7072
7173@contextlib .contextmanager
@@ -75,9 +77,13 @@ def maybe_record_goodput(recorder, event_name, *args):
7577 start_event_name = f"record_{ event_name .value } _start_time"
7678 record_goodput (recorder , start_event_name , * args )
7779 yield
78- finally :
80+ except BaseException : # pylint: disable=W0706
81+ raise
82+ else :
7983 end_event_name = f"record_{ event_name .value } _end_time"
8084 record_goodput (recorder , end_event_name , * args )
85+ finally :
86+ pass
8187
8288
8389def record_goodput (recorder , event_name , * args ):
0 commit comments