Skip to content

Commit 5aba3ed

Browse files
Feature:4135 Orchestrator generic graceful stopping
- Early stopping on unhealthy heartbeat - Use pipeline status for heartbeat - Some extra utils and improvements
1 parent adf775f commit 5aba3ed

File tree

9 files changed

+126
-35
lines changed

9 files changed

+126
-35
lines changed

src/zenml/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4947,7 +4947,7 @@ def list_run_steps(
49474947
cache_expired: Whether the cache expiration time of the step run
49484948
has passed.
49494949
code_hash: The code hash of the step run to filter by.
4950-
status: The name of the run to filter by.
4950+
status: The status of the step run.
49514951
run_metadata: Filter by run metadata.
49524952
exclude_retried: Whether to exclude retried step runs.
49534953
hydrate: Flag deciding whether to hydrate the output model(s)

src/zenml/config/step_configurations.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,13 @@ class StepConfigurationUpdate(FrozenBaseModel):
221221
"run inline unless a step operator or docker/resource settings "
222222
"are configured. This is only applicable for dynamic pipelines.",
223223
)
224+
heartbeat_healthy_threshold: int | None = Field(
225+
default=None,
226+
description="The amount of time (in minutes) that a running step "
227+
"has not received heartbeat and is considered healthy. Set null value"
228+
"disable healthiness checks via heartbeat.",
229+
ge=1,
230+
)
224231

225232
outputs: Mapping[str, PartialArtifactConfiguration] = {}
226233

src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,20 @@ def stop_step(node: Node) -> None:
601601
)
602602
break
603603

604+
def is_node_heartbeat_unhealthy(node: Node) -> bool:
605+
from zenml.steps.heartbeat import is_heartbeat_unhealthy
606+
607+
sr_ = client.list_run_steps(
608+
name=node.id, pipeline_run_id=pipeline_run.id
609+
)
610+
611+
if sr_.items:
612+
sr_ = sr_[0]
613+
614+
return is_heartbeat_unhealthy(step_run=sr_)
615+
616+
return False
617+
604618
def check_job_status(node: Node) -> NodeStatus:
605619
"""Check the status of a job.
606620
@@ -641,6 +655,12 @@ def check_job_status(node: Node) -> NodeStatus:
641655
error_message,
642656
)
643657
return NodeStatus.FAILED
658+
elif is_node_heartbeat_unhealthy(node):
659+
logger.error(
660+
"Heartbeat for step `%s` indicates unhealthy status.",
661+
step_name,
662+
)
663+
return NodeStatus.FAILED
644664
else:
645665
return NodeStatus.RUNNING
646666

src/zenml/models/v2/core/step_run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,3 +823,4 @@ class StepHeartbeatResponse(BaseModel, use_enum_values=True):
823823
id: UUID
824824
status: ExecutionStatus
825825
latest_heartbeat: datetime
826+
pipeline_run_status: ExecutionStatus | None = None

src/zenml/orchestrators/base_orchestrator.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -678,27 +678,16 @@ def stop_run(
678678
If False, forces immediate termination. Default is False.
679679
680680
Raises:
681-
NotImplementedError: If any orchestrator inheriting from the base
682-
class does not implement this logic.
683681
IllegalOperationError: If the run has no orchestrator run id yet.
684682
"""
685-
# Check if the orchestrator supports cancellation
686-
if (
687-
getattr(self._stop_run, "__func__", None)
688-
is BaseOrchestrator._stop_run
689-
):
690-
raise NotImplementedError(
691-
f"The '{self.__class__.__name__}' orchestrator does not "
692-
"support stopping pipeline runs."
693-
)
694-
695683
if not run.orchestrator_run_id:
696684
raise IllegalOperationError(
697685
"Cannot stop a pipeline run that has no orchestrator run id "
698686
"yet."
699687
)
700688

701689
# Update pipeline status to STOPPING before calling concrete implementation
690+
# Initiates graceful termination.
702691
publish_pipeline_run_status_update(
703692
pipeline_run_id=run.id,
704693
status=ExecutionStatus.STOPPING,
@@ -720,13 +709,24 @@ def _stop_run(
720709
run: A pipeline run response to stop (already updated to STOPPING status).
721710
graceful: If True, allows for graceful shutdown where possible.
722711
If False, forces immediate termination. Default is True.
723-
724-
Raises:
725-
NotImplementedError: If any orchestrator inheriting from the base
726-
class does not implement this logic.
727712
"""
713+
if graceful:
714+
# This should work out of the box for HeartBeat step termination.
715+
# Orchestrators should extend the functionality to cover other scenarios.
716+
self._stop_run_gracefully(pipeline_run=run)
717+
else:
718+
self._stop_run_forcefully(pipeline_run=run)
719+
720+
def _stop_run_gracefully(
721+
self, pipeline_run: "PipelineRunResponse"
722+
) -> None:
723+
pass
724+
725+
def _stop_run_forcefully(
726+
self, pipeline_run: "PipelineRunResponse"
727+
) -> None:
728728
raise NotImplementedError(
729-
"The stop run functionality is not implemented for the "
729+
"The forceful stop run functionality is not implemented for the "
730730
f"'{self.__class__.__name__}' orchestrator."
731731
)
732732

src/zenml/steps/heartbeat.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717
import logging
1818
import threading
1919
import time
20+
from datetime import datetime, timezone
21+
from typing import TYPE_CHECKING
2022
from uuid import UUID
2123

2224
from zenml.enums import ExecutionStatus
2325

26+
if TYPE_CHECKING:
27+
from zenml.models import StepRunResponse
28+
2429
logger = logging.getLogger(__name__)
2530

2631

@@ -161,8 +166,43 @@ def _heartbeat(self) -> None:
161166

162167
response = store.update_step_heartbeat(step_run_id=self.step_id)
163168

164-
if response.status in {
169+
if response.pipeline_run_status in {
165170
ExecutionStatus.STOPPED,
166171
ExecutionStatus.STOPPING,
167172
}:
168173
self._terminated = True
174+
175+
176+
def is_heartbeat_unhealthy(step_run: "StepRunResponse") -> bool:
177+
"""Utility function - Checks if step heartbeats indicate un-healthy execution.
178+
179+
Args:
180+
step_run: Information regarding a step run.
181+
182+
Returns:
183+
True if the step heartbeat is unhealthy, False otherwise.
184+
"""
185+
if not step_run.spec.enable_heartbeat:
186+
return False
187+
188+
if not step_run.config.heartbeat_healthy_threshold:
189+
return False
190+
191+
if step_run.status.is_finished:
192+
heartbeat_diff = step_run.end_time - (
193+
step_run.latest_heartbeat or step_run.start_time
194+
)
195+
else:
196+
heartbeat_diff = datetime.now(tz=timezone.utc) - (
197+
step_run.latest_heartbeat or step_run.start_time
198+
)
199+
200+
logger.info("%s heartbeat diff=%s", step_run.name, heartbeat_diff)
201+
202+
if (
203+
heartbeat_diff.total_seconds()
204+
> step_run.config.heartbeat_healthy_threshold * 60
205+
):
206+
return True
207+
208+
return False

src/zenml/utils/pagination_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Pagination utilities."""
1515

16-
from typing import Any, Callable, List, TypeVar
16+
from typing import Any, Callable, Generator, List, TypeVar
1717

1818
from zenml.models import BaseIdentifiedResponse, Page
1919

@@ -40,3 +40,28 @@ def depaginate(
4040
items += list(page.items)
4141

4242
return items
43+
44+
45+
def depaginate_stream(
46+
list_method: Callable[..., Page[AnyResponse]], **kwargs: Any
47+
) -> Generator[AnyResponse, None, None]:
48+
"""Depaginate the results from a client or store method that returns pages.
49+
50+
Args:
51+
list_method: The list method to depaginate.
52+
**kwargs: Arguments for the list method.
53+
54+
Yields:
55+
A generator of the corresponding Response Models.
56+
"""
57+
page = list_method(**kwargs)
58+
59+
for item in page.items:
60+
yield item
61+
62+
while page.index < page.total_pages:
63+
kwargs["page"] = page.index + 1
64+
page = list_method(**kwargs)
65+
66+
for item in page.items:
67+
yield item

src/zenml/zen_server/routers/steps_endpoints.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -226,44 +226,41 @@ def update_heartbeat(
226226
"""
227227
step = zen_store().get_run_step(step_run_id, hydrate=False)
228228

229-
# Avoid using status.is_finished as it invalidates useful statuses for heartbeat
230-
# such as STOPPED.
231-
if step.status.is_failed or step.status.is_successful:
229+
if step.is_finished:
232230
raise HTTPException(
233231
status_code=422,
234232
detail=f"Step {step.id} is finished - can not update heartbeat.",
235233
)
236234

237-
def validate_token_access(
238-
ctx: AuthContext, step_: StepRunResponse
239-
) -> None:
235+
pipeline_run = zen_store().get_run(step.pipeline_run_id, hydrate=False)
236+
237+
def validate_token_access(ctx: AuthContext) -> None:
240238
token_run_id = ctx.access_token.pipeline_run_id # type: ignore[union-attr]
241239
token_schedule_id = ctx.access_token.schedule_id # type: ignore[union-attr]
242240

243241
if token_run_id:
244-
if step_.pipeline_run_id != token_run_id:
242+
if step.pipeline_run_id != token_run_id:
245243
raise AuthorizationException(
246-
f"Authentication token provided is invalid for step: {step_.id}"
244+
f"Authentication token provided is invalid for step: {step.id}"
247245
)
248246
elif token_schedule_id:
249-
pipeline_run = zen_store().get_run(
250-
step_.pipeline_run_id, hydrate=False
251-
)
252-
253247
if not (
254248
pipeline_run.schedule
255249
and pipeline_run.schedule.id == token_schedule_id
256250
):
257251
raise AuthorizationException(
258-
f"Authentication token provided is invalid for step: {step_.id}"
252+
f"Authentication token provided is invalid for step: {step.id}"
259253
)
260254
else:
261255
# un-scoped token. Soon to-be-deprecated, we will ignore validation temporarily.
262256
pass
263257

264-
validate_token_access(ctx=auth_context, step_=step)
258+
validate_token_access(ctx=auth_context)
259+
260+
hb = zen_store().update_step_heartbeat(step_run_id=step_run_id)
261+
hb.pipeline_run_status = pipeline_run.status
265262

266-
return zen_store().update_step_heartbeat(step_run_id=step_run_id)
263+
return hb
267264

268265

269266
@router.get(

tests/integration/functional/steps/test_heartbeat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_heartbeat_rest_functionality():
8888

8989
assert hb_response.status == ExecutionStatus.RUNNING
9090
assert hb_response.latest_heartbeat is not None
91+
assert hb_response.pipeline_run_status == ExecutionStatus.RUNNING
9192

9293
assert (
9394
client.zen_store.get_run_step(step_run_id=step_run.id).latest_heartbeat

0 commit comments

Comments
 (0)