Skip to content

Commit ef01565

Browse files
Extend pipeline in_progress checks
1 parent 179d179 commit ef01565

File tree

7 files changed

+141
-32
lines changed

7 files changed

+141
-32
lines changed

src/zenml/models/v2/core/step_run.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ class StepRunResponseBody(ProjectScopedResponseBody):
239239
title="The substitutions of the step run.",
240240
default={},
241241
)
242+
cached_heartbeat_threshold: Optional[int] = Field(title="", default=None)
242243
model_config = ConfigDict(protected_namespaces=())
243244

244245

@@ -609,6 +610,15 @@ def latest_heartbeat(self) -> Optional[datetime]:
609610
"""
610611
return self.get_body().latest_heartbeat
611612

613+
@property
def cached_heartbeat_threshold(self) -> Optional[int]:
    """Expose the `cached_heartbeat_threshold` value of the response body.

    Returns:
        The value stored on the body, which may be `None`.
    """
    body = self.get_body()
    return body.cached_heartbeat_threshold
621+
612622
@property
613623
def logs(self) -> Optional["LogsResponse"]:
614624
"""The `logs` property.

src/zenml/steps/heartbeat.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from uuid import UUID
2323

2424
from zenml.enums import ExecutionStatus
25+
from zenml.utils.time_utils import to_local_tz
2526

2627
if TYPE_CHECKING:
2728
from zenml.models import StepRunResponse
@@ -173,36 +174,64 @@ def _heartbeat(self) -> None:
173174
self._terminated = True
174175

175176

176-
def is_heartbeat_unhealthy(step_run: "StepRunResponse") -> bool:
177+
def cached_is_heartbeat_unhealthy(
    step_run_id: UUID,
    status: ExecutionStatus,
    latest_heartbeat: datetime | None,
    start_time: datetime | None = None,
    heartbeat_threshold: int | None = None,
) -> bool:
    """Decide from plain values whether a step's heartbeat has gone stale.

    Args:
        step_run_id: ID of the step run; only used in the debug log line.
        status: Current execution status of the step run.
        latest_heartbeat: Time of the most recent heartbeat, if any.
        start_time: Start time of the step run; used as the reference
            point when no heartbeat has been recorded.
        heartbeat_threshold: Maximum number of minutes that may pass
            without a heartbeat. A falsy value (`None` or `0`) disables
            the check.

    Returns:
        True if the step heartbeat is unhealthy, False otherwise.
    """
    # Without a threshold the check is disabled, and finished steps are
    # never reported as unhealthy.
    if not heartbeat_threshold or status.is_finished:
        return False

    # Prefer the latest heartbeat; fall back to the start time.
    last_signal = latest_heartbeat if latest_heartbeat else start_time
    if not last_signal:
        return False

    heartbeat_diff = datetime.now(tz=timezone.utc) - to_local_tz(last_signal)

    logger.debug("Step %s heartbeat diff=%s", step_run_id, heartbeat_diff)

    # The threshold is expressed in minutes, the difference in seconds.
    return heartbeat_diff.total_seconds() > heartbeat_threshold * 60
220+
221+
222+
def is_heartbeat_unhealthy(step_run: "StepRunResponse") -> bool:
    """Check whether a step run's heartbeat indicates unhealthy execution.

    Unpacks the relevant fields of the response model and delegates the
    decision to `cached_is_heartbeat_unhealthy`.

    Args:
        step_run: Information regarding a step run.

    Returns:
        True if the step heartbeat is unhealthy, False otherwise.
    """
    run_id = step_run.id
    run_status = step_run.status
    started_at = step_run.start_time
    threshold = step_run.cached_heartbeat_threshold
    last_beat = step_run.latest_heartbeat

    return cached_is_heartbeat_unhealthy(
        step_run_id=run_id,
        status=run_status,
        latest_heartbeat=last_beat,
        start_time=started_at,
        heartbeat_threshold=threshold,
    )
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Add cached heartbeat threshold column [2837a7ab3d77].
2+
3+
Revision ID: 2837a7ab3d77
4+
Revises: d203788f82b9
5+
Create Date: 2025-11-28 12:27:14.341553
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "2837a7ab3d77"
14+
down_revision = "d203788f82b9"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
    """Add the nullable `cached_heartbeat_threshold` column to `step_run`."""
    threshold_column = sa.Column(
        "cached_heartbeat_threshold", sa.Integer(), nullable=True
    )

    with op.batch_alter_table("step_run", schema=None) as step_run_table:
        step_run_table.add_column(threshold_column)
27+
28+
29+
def downgrade() -> None:
    """Remove `cached_heartbeat_threshold` from the `step_run` table."""
    with op.batch_alter_table("step_run", schema=None) as step_run_table:
        step_run_table.drop_column("cached_heartbeat_threshold")

src/zenml/zen_stores/schemas/pipeline_run_schemas.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pydantic import ConfigDict
2222
from sqlalchemy import UniqueConstraint
23-
from sqlalchemy.orm import object_session, selectinload
23+
from sqlalchemy.orm import Session, object_session, selectinload
2424
from sqlalchemy.sql.base import ExecutableOption
2525
from sqlmodel import TEXT, Column, Field, Relationship, select
2626

@@ -815,6 +815,30 @@ def is_placeholder_run(self) -> bool:
815815
ExecutionStatus.PROVISIONING.value,
816816
}
817817

818+
@staticmethod
def _step_in_progress(step_id: UUID, session: Session) -> bool:
    """Check whether a single step run still counts as in progress.

    Args:
        step_id: ID of the step run to look up.
        session: Database session used to load the step run.

    Returns:
        False if no step run with that ID exists, if its status is
        finished, or if its heartbeat is considered unhealthy; True
        otherwise.
    """
    # NOTE(review): function-scope import, presumably to avoid an import
    # cycle with the steps package - confirm before hoisting it.
    from zenml.steps.heartbeat import cached_is_heartbeat_unhealthy

    query = select(StepRunSchema).where(StepRunSchema.id == step_id)
    step_run = session.execute(query).scalar_one_or_none()

    if not step_run:
        return False

    status: ExecutionStatus = ExecutionStatus(step_run.status)
    if status.is_finished:
        return False

    heartbeat_is_stale = cached_is_heartbeat_unhealthy(
        step_run_id=step_run.id,
        status=status,
        heartbeat_threshold=step_run.cached_heartbeat_threshold,
        start_time=step_run.start_time,
        latest_heartbeat=step_run.latest_heartbeat,
    )
    return not heartbeat_is_stale
841+
818842
def _check_if_run_in_progress(self) -> bool:
819843
"""Checks whether the run is in progress.
820844
@@ -843,19 +867,25 @@ def _check_if_run_in_progress(self) -> bool:
843867

844868
if session := object_session(self):
845869
step_run_statuses = session.execute(
846-
select(StepRunSchema.name, StepRunSchema.status).where(
847-
StepRunSchema.pipeline_run_id == self.id
848-
)
870+
select(
871+
StepRunSchema.id,
872+
StepRunSchema.name,
873+
StepRunSchema.status,
874+
).where(StepRunSchema.pipeline_run_id == self.id)
849875
).all()
850876

851877
if self.snapshot and self.snapshot.pipeline_spec:
852878
step_dict = self.get_upstream_steps()
853879

854880
dag = build_dag(step_dict)
855881

882+
step_name_to_id = {
883+
name: id_ for id_, name, _ in step_run_statuses
884+
}
885+
856886
failed_steps = {
857887
name
858-
for name, status in step_run_statuses
888+
for _, name, status in step_run_statuses
859889
if ExecutionStatus(status).is_failed
860890
}
861891

@@ -874,12 +904,21 @@ def _check_if_run_in_progress(self) -> bool:
874904

875905
for step_name, _ in step_dict.items():
876906
if step_name in steps_to_skip:
907+
# failed steps downstream
908+
continue
909+
910+
if steps_statuses[step_name].is_finished:
911+
# completed steps
877912
continue
878913

879914
if step_name not in steps_statuses:
915+
# steps that haven't started yet
880916
return True
881917

882-
elif not steps_statuses[step_name].is_finished:
918+
if self._step_in_progress(
919+
step_name_to_id[step_name], session=session
920+
):
921+
# running steps without unhealthy heartbeats
883922
return True
884923

885924
return False

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
117117
nullable=True,
118118
)
119119
)
120-
120+
cached_heartbeat_threshold: Optional[int] = Field(nullable=True)
121121
# Foreign keys
122122
original_step_run_id: Optional[UUID] = build_foreign_key_field(
123123
source=__tablename__,
@@ -427,6 +427,7 @@ def to_model(
427427
start_time=self.start_time,
428428
end_time=self.end_time,
429429
latest_heartbeat=self.latest_heartbeat,
430+
cached_heartbeat_threshold=self.cached_heartbeat_threshold,
430431
created=self.created,
431432
updated=self.updated,
432433
model_version_id=self.model_version_id,

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10095,6 +10095,13 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
1009510095
is_retriable=is_retriable,
1009610096
)
1009710097

10098+
# cached top-level heartbeat config property (for fast validation).
10099+
step_schema.cached_heartbeat_threshold = (
10100+
step_config.config.heartbeat_healthy_threshold
10101+
if step_config.spec.enable_heartbeat
10102+
else None
10103+
)
10104+
1009810105
session.add(step_schema)
1009910106
try:
1010010107
session.commit()

tests/unit/steps/test_heartbeat_worker.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def test_heartbeat_healthiness_check(monkeypatch):
140140
is_retriable=False,
141141
start_time=None,
142142
latest_heartbeat=None,
143+
cached_heartbeat_threshold=5,
143144
),
144145
metadata=StepRunResponseMetadata(
145146
config=StepConfiguration(
@@ -189,15 +190,5 @@ def now(cls, tz=None):
189190
assert not heartbeat.is_heartbeat_unhealthy(step_run)
190191

191192
# if step heartbeat not enabled = healthy (default response)
192-
step_run.metadata.config = StepConfiguration(
193-
name="test", heartbeat_healthy_threshold=5
194-
)
195-
assert heartbeat.is_heartbeat_unhealthy(step_run)
196-
step_run.metadata.spec = StepSpec(
197-
enable_heartbeat=False,
198-
source=Source(module="test", type=SourceType.BUILTIN),
199-
upstream_steps=["test"],
200-
inputs={"test": InputSpec(step_name="test", output_name="test")},
201-
invocation_id="test",
202-
)
193+
step_run.body.cached_heartbeat_threshold = None
203194
assert not heartbeat.is_heartbeat_unhealthy(step_run)

0 commit comments

Comments
 (0)