Skip to content

Commit 43b78fd

Browse files
Extend pipeline in_progress checks
1 parent b3a576a commit 43b78fd

File tree

7 files changed

+141
-32
lines changed

7 files changed

+141
-32
lines changed

src/zenml/models/v2/core/step_run.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ class StepRunResponseBody(ProjectScopedResponseBody):
239239
title="The substitutions of the step run.",
240240
default={},
241241
)
242+
cached_heartbeat_threshold: Optional[int] = Field(title="", default=None)
242243
model_config = ConfigDict(protected_namespaces=())
243244

244245

@@ -609,6 +610,15 @@ def latest_heartbeat(self) -> Optional[datetime]:
609610
"""
610611
return self.get_body().latest_heartbeat
611612

613+
@property
614+
def cached_heartbeat_threshold(self) -> Optional[int]:
615+
"""The `cached_heartbeat_threshold` property.
616+
617+
Returns:
618+
the value of the property.
619+
"""
620+
return self.get_body().cached_heartbeat_threshold
621+
612622
@property
613623
def logs(self) -> Optional["LogsResponse"]:
614624
"""The `logs` property.

src/zenml/steps/heartbeat.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from uuid import UUID
2323

2424
from zenml.enums import ExecutionStatus
25+
from zenml.utils.time_utils import to_local_tz
2526

2627
if TYPE_CHECKING:
2728
from zenml.models import StepRunResponse
@@ -173,36 +174,64 @@ def _heartbeat(self) -> None:
173174
self._terminated = True
174175

175176

176-
def is_heartbeat_unhealthy(step_run: "StepRunResponse") -> bool:
177+
def cached_is_heartbeat_unhealthy(
178+
step_run_id: UUID,
179+
status: ExecutionStatus,
180+
latest_heartbeat: datetime | None,
181+
start_time: datetime | None = None,
182+
heartbeat_threshold: int | None = None,
183+
) -> bool:
177184
"""Utility function - Checks if step heartbeats indicate un-healthy execution.
178185
179186
Args:
180-
step_run: Information regarding a step run.
187+
step_run_id: The ID of the step run.
188+
status: The execution status of the step run.
189+
latest_heartbeat: The time of the latest heartbeat recorded for the step run, if any.
190+
start_time: The start time of the step run, used as a fallback when no heartbeat has been recorded.
191+
heartbeat_threshold: The maximum number of minutes a running step may go
192+
without a heartbeat and still be considered healthy. A falsy value (None or 0) disables the check.
181193
182194
Returns:
183195
True if the step heartbeat is unhealthy, False otherwise.
184196
"""
185-
if not step_run.spec.enable_heartbeat:
197+
if not heartbeat_threshold:
186198
return False
187199

188-
if step_run.status.is_finished:
200+
if status.is_finished:
189201
return False
190202

191-
if step_run.latest_heartbeat:
192-
heartbeat_diff = (
193-
datetime.now(tz=timezone.utc) - step_run.latest_heartbeat
203+
if latest_heartbeat:
204+
heartbeat_diff = datetime.now(tz=timezone.utc) - to_local_tz(
205+
latest_heartbeat
206+
)
207+
elif start_time:
208+
heartbeat_diff = datetime.now(tz=timezone.utc) - to_local_tz(
209+
start_time
194210
)
195-
elif step_run.start_time:
196-
heartbeat_diff = datetime.now(tz=timezone.utc) - step_run.start_time
197211
else:
198212
return False
199213

200-
logger.info("%s heartbeat diff=%s", step_run.name, heartbeat_diff)
214+
logger.info("%s heartbeat diff=%s", step_run_id, heartbeat_diff)
201215

202-
if (
203-
heartbeat_diff.total_seconds()
204-
> step_run.config.heartbeat_healthy_threshold * 60
205-
):
216+
if heartbeat_diff.total_seconds() > heartbeat_threshold * 60:
206217
return True
207218

208219
return False
220+
221+
222+
def is_heartbeat_unhealthy(step_run: "StepRunResponse") -> bool:
223+
"""Utility function - Checks if step heartbeats indicate un-healthy execution.
224+
225+
Args:
226+
step_run: Information regarding a step run.
227+
228+
Returns:
229+
True if the step heartbeat is unhealthy, False otherwise.
230+
"""
231+
return cached_is_heartbeat_unhealthy(
232+
step_run_id=step_run.id,
233+
status=step_run.status,
234+
start_time=step_run.start_time,
235+
heartbeat_threshold=step_run.cached_heartbeat_threshold,
236+
latest_heartbeat=step_run.latest_heartbeat,
237+
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Add cached heartbeat threshold column [2837a7ab3d77].
2+
3+
Revision ID: 2837a7ab3d77
4+
Revises: d203788f82b9
5+
Create Date: 2025-11-28 12:27:14.341553
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "2837a7ab3d77"
14+
down_revision = "d203788f82b9"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
"""Upgrade database schema and/or data, creating a new revision."""
21+
with op.batch_alter_table("step_run", schema=None) as batch_op:
22+
batch_op.add_column(
23+
sa.Column(
24+
"cached_heartbeat_threshold", sa.Integer(), nullable=True
25+
)
26+
)
27+
28+
29+
def downgrade() -> None:
30+
"""Downgrade database schema and/or data back to the previous revision."""
31+
with op.batch_alter_table("step_run", schema=None) as batch_op:
32+
batch_op.drop_column("cached_heartbeat_threshold")

src/zenml/zen_stores/schemas/pipeline_run_schemas.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pydantic import ConfigDict
2222
from sqlalchemy import UniqueConstraint
23-
from sqlalchemy.orm import object_session, selectinload
23+
from sqlalchemy.orm import Session, object_session, selectinload
2424
from sqlalchemy.sql.base import ExecutableOption
2525
from sqlmodel import TEXT, Column, Field, Relationship, select
2626

@@ -802,6 +802,30 @@ def is_placeholder_run(self) -> bool:
802802
ExecutionStatus.PROVISIONING.value,
803803
}
804804

805+
@staticmethod
806+
def _step_in_progress(step_id: UUID, session: Session) -> bool:
807+
from zenml.steps.heartbeat import cached_is_heartbeat_unhealthy
808+
809+
step_run = session.execute(
810+
select(StepRunSchema).where(StepRunSchema.id == step_id)
811+
).scalar_one_or_none()
812+
813+
if not step_run:
814+
return False
815+
816+
status: ExecutionStatus = ExecutionStatus(step_run.status)
817+
818+
return not (
819+
status.is_finished
820+
or cached_is_heartbeat_unhealthy(
821+
step_run_id=step_run.id,
822+
status=status,
823+
heartbeat_threshold=step_run.cached_heartbeat_threshold,
824+
start_time=step_run.start_time,
825+
latest_heartbeat=step_run.latest_heartbeat,
826+
)
827+
)
828+
805829
def _check_if_run_in_progress(self) -> bool:
806830
"""Checks whether the run is in progress.
807831
@@ -830,19 +854,25 @@ def _check_if_run_in_progress(self) -> bool:
830854

831855
if session := object_session(self):
832856
step_run_statuses = session.execute(
833-
select(StepRunSchema.name, StepRunSchema.status).where(
834-
StepRunSchema.pipeline_run_id == self.id
835-
)
857+
select(
858+
StepRunSchema.id,
859+
StepRunSchema.name,
860+
StepRunSchema.status,
861+
).where(StepRunSchema.pipeline_run_id == self.id)
836862
).all()
837863

838864
if self.snapshot and self.snapshot.pipeline_spec:
839865
step_dict = self.get_upstream_steps()
840866

841867
dag = build_dag(step_dict)
842868

869+
step_name_to_id = {
870+
name: id_ for id_, name, _ in step_run_statuses
871+
}
872+
843873
failed_steps = {
844874
name
845-
for name, status in step_run_statuses
875+
for _, name, status in step_run_statuses
846876
if ExecutionStatus(status).is_failed
847877
}
848878

@@ -861,12 +891,21 @@ def _check_if_run_in_progress(self) -> bool:
861891

862892
for step_name, _ in step_dict.items():
863893
if step_name in steps_to_skip:
894+
# failed steps downstream
895+
continue
896+
897+
if steps_statuses[step_name].is_finished:
898+
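# completed steps (NOTE(review): this lookup runs before the "not in steps_statuses" check below, so it presumably raises KeyError for steps that have not started yet - confirm and reorder the two checks)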
864899
continue
865900

866901
if step_name not in steps_statuses:
902+
# steps that haven't started yet
867903
return True
868904

869-
elif not steps_statuses[step_name].is_finished:
905+
if self._step_in_progress(
906+
step_name_to_id[step_name], session=session
907+
):
908+
# running steps without unhealthy heartbeats
870909
return True
871910

872911
return False

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
117117
nullable=True,
118118
)
119119
)
120-
120+
cached_heartbeat_threshold: Optional[int] = Field(nullable=True)
121121
# Foreign keys
122122
original_step_run_id: Optional[UUID] = build_foreign_key_field(
123123
source=__tablename__,
@@ -427,6 +427,7 @@ def to_model(
427427
start_time=self.start_time,
428428
end_time=self.end_time,
429429
latest_heartbeat=self.latest_heartbeat,
430+
cached_heartbeat_threshold=self.cached_heartbeat_threshold,
430431
created=self.created,
431432
updated=self.updated,
432433
model_version_id=self.model_version_id,

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10095,6 +10095,13 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
1009510095
is_retriable=is_retriable,
1009610096
)
1009710097

10098+
# Cache the heartbeat threshold as a top-level column (None when heartbeats are disabled) for fast health checks.
10099+
step_schema.cached_heartbeat_threshold = (
10100+
step_config.config.heartbeat_healthy_threshold
10101+
if step_config.spec.enable_heartbeat
10102+
else None
10103+
)
10104+
1009810105
session.add(step_schema)
1009910106
try:
1010010107
session.commit()

tests/unit/steps/test_heartbeat_worker.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def test_heartbeat_healthiness_check(monkeypatch):
140140
is_retriable=False,
141141
start_time=None,
142142
latest_heartbeat=None,
143+
cached_heartbeat_threshold=5,
143144
),
144145
metadata=StepRunResponseMetadata(
145146
config=StepConfiguration(
@@ -189,15 +190,5 @@ def now(cls, tz=None):
189190
assert not heartbeat.is_heartbeat_unhealthy(step_run)
190191

191192
# if the step heartbeat is not enabled, the step is reported as healthy (default response)
192-
step_run.metadata.config = StepConfiguration(
193-
name="test", heartbeat_healthy_threshold=5
194-
)
195-
assert heartbeat.is_heartbeat_unhealthy(step_run)
196-
step_run.metadata.spec = StepSpec(
197-
enable_heartbeat=False,
198-
source=Source(module="test", type=SourceType.BUILTIN),
199-
upstream_steps=["test"],
200-
inputs={"test": InputSpec(step_name="test", output_name="test")},
201-
invocation_id="test",
202-
)
193+
step_run.body.cached_heartbeat_threshold = None
203194
assert not heartbeat.is_heartbeat_unhealthy(step_run)

0 commit comments

Comments
 (0)