Skip to content

Commit d817e63

Browse files
Feature:4135 Handle heartbeat interrupts properly
1 parent 5aba3ed commit d817e63

File tree

3 files changed

+75
-39
lines changed

3 files changed

+75
-39
lines changed

src/zenml/orchestrators/publish_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,23 @@ def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
118118
)
119119

120120

121+
def publish_stopped_step_run(step_run_id: "UUID") -> "StepRunResponse":
    """Marks a step run as stopped and returns the updated record.

    The step run is updated to `ExecutionStatus.STOPPED`, its end time is
    set to the value returned by `utc_now()`, and whatever
    `step_exception_info.get()` currently yields is forwarded as the
    exception info.

    Args:
        step_run_id: The ID of the step run to update.

    Returns:
        The updated step run.
    """
    # Take the timestamp first, then read the exception info, so the
    # evaluation order matches passing both expressions inline.
    stopped_at = utc_now()
    # NOTE(review): presumably a context-local holder that may be unset
    # when the step was stopped rather than failed - confirm.
    recorded_exception = step_exception_info.get()

    return publish_step_run_status_update(
        step_run_id=step_run_id,
        status=ExecutionStatus.STOPPED,
        end_time=stopped_at,
        exception_info=recorded_exception,
    )
136+
137+
121138
def publish_successful_pipeline_run(
122139
pipeline_run_id: "UUID",
123140
) -> "PipelineRunResponse":

src/zenml/orchestrators/step_launcher.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from zenml.orchestrators import utils as orchestrator_utils
4444
from zenml.orchestrators.step_runner import StepRunner
4545
from zenml.stack import Stack
46+
from zenml.steps import StepHeartBeatTerminationException
4647
from zenml.utils import env_utils, exception_utils, string_utils
4748
from zenml.utils.time_utils import utc_now
4849

@@ -361,13 +362,22 @@ def _bypass() -> None:
361362
except RunStoppedException as e:
362363
raise e
363364
except BaseException as e: # noqa: E722
364-
logger.error(
365-
"Failed to run step `%s`: %s",
366-
self._invocation_id,
367-
e,
368-
)
369-
publish_utils.publish_failed_step_run(step_run.id)
370-
raise
365+
step_run = Client().get_run_step(step_run_id=step_run.id)
366+
367+
if (
368+
isinstance(e, StepHeartBeatTerminationException)
369+
or step_run.status == ExecutionStatus.STOPPING
370+
):
371+
# Treat this as a non-failure: the exception is the propagation of a graceful termination.
372+
publish_utils.publish_stopped_step_run(step_run.id)
373+
else:
374+
logger.error(
375+
"Failed to run step `%s`: %s",
376+
self._invocation_id,
377+
e,
378+
)
379+
publish_utils.publish_failed_step_run(step_run.id)
380+
raise
371381
else:
372382
logger.info(
373383
f"Using cached version of step `{self._invocation_id}`."

src/zenml/orchestrators/step_runner.py

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
ENV_ZENML_STEP_OPERATOR,
3838
handle_bool_env_var,
3939
)
40-
from zenml.enums import ArtifactSaveType
40+
from zenml.enums import ArtifactSaveType, ExecutionStatus
4141
from zenml.exceptions import StepInterfaceError
4242
from zenml.hooks.hook_validators import load_and_run_hook
4343
from zenml.logger import get_logger
@@ -251,46 +251,55 @@ def run(
251251
)
252252
except BaseException as step_exception: # noqa: E722
253253
step_failed = True
254-
255-
exception_info = (
256-
exception_utils.collect_exception_information(
257-
step_exception, step_instance
258-
)
259-
)
260-
261-
if ENV_ZENML_STEP_OPERATOR in os.environ:
262-
# We're running in a step operator environment, so we can't
263-
# depend on the step launcher to publish the exception info
254+
if (
255+
isinstance(step_exception, KeyboardInterrupt)
256+
and heartbeat_worker.is_terminated
257+
):
264258
Client().zen_store.update_run_step(
265259
step_run_id=step_run_info.step_run_id,
266260
step_run_update=StepRunUpdate(
267-
exception_info=exception_info,
261+
status=ExecutionStatus.STOPPING,
268262
),
269263
)
270-
else:
271-
# This will be published by the step launcher
272-
step_exception_info.set(exception_info)
273264

274-
if not step_run.is_retriable:
275-
if (
276-
failure_hook_source
277-
:= self.configuration.failure_hook_source
278-
):
279-
logger.info("Detected failure hook. Running...")
280-
with env_utils.temporary_environment(
281-
step_environment
282-
):
283-
load_and_run_hook(
284-
failure_hook_source,
285-
step_exception=step_exception,
286-
)
287-
if (
288-
isinstance(step_exception, KeyboardInterrupt)
289-
and heartbeat_worker.is_terminated
290-
):
291265
raise StepHeartBeatTerminationException(
292266
"Remotely stopped step - terminating execution."
293267
)
268+
else:
269+
exception_info = (
270+
exception_utils.collect_exception_information(
271+
step_exception, step_instance
272+
)
273+
)
274+
275+
if ENV_ZENML_STEP_OPERATOR in os.environ:
276+
# We're running in a step operator environment, so we can't
277+
# depend on the step launcher to publish the exception info
278+
Client().zen_store.update_run_step(
279+
step_run_id=step_run_info.step_run_id,
280+
step_run_update=StepRunUpdate(
281+
exception_info=exception_info,
282+
),
283+
)
284+
else:
285+
# This will be published by the step launcher
286+
step_exception_info.set(exception_info)
287+
288+
if not step_run.is_retriable:
289+
if (
290+
failure_hook_source
291+
:= self.configuration.failure_hook_source
292+
):
293+
logger.info(
294+
"Detected failure hook. Running..."
295+
)
296+
with env_utils.temporary_environment(
297+
step_environment
298+
):
299+
load_and_run_hook(
300+
failure_hook_source,
301+
step_exception=step_exception,
302+
)
294303
raise step_exception
295304
finally:
296305
heartbeat_worker.stop()

0 commit comments

Comments
 (0)