Skip to content

Commit 6dee130

Browse files
Feature:4135 Orchestrator generic graceful stopping
1 parent dbcfe3f commit 6dee130

File tree

3 files changed

+102
-16
lines changed

3 files changed

+102
-16
lines changed

src/zenml/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4947,7 +4947,7 @@ def list_run_steps(
49474947
cache_expired: Whether the cache expiration time of the step run
49484948
has passed.
49494949
code_hash: The code hash of the step run to filter by.
4950-
status: The name of the run to filter by.
4950+
status: The status of the step run.
49514951
run_metadata: Filter by run metadata.
49524952
exclude_retried: Whether to exclude retried step runs.
49534953
hydrate: Flag deciding whether to hydrate the output model(s)

src/zenml/orchestrators/base_orchestrator.py

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from zenml.stack import Flavor, Stack, StackComponent, StackComponentConfig
5353
from zenml.steps.step_context import RunContext, get_or_create_run_context
5454
from zenml.utils.env_utils import temporary_environment
55+
from zenml.utils.pagination_utils import depaginate_stream
5556
from zenml.utils.pydantic_utils import before_validator_handler
5657

5758
if TYPE_CHECKING:
@@ -678,19 +679,8 @@ def stop_run(
678679
If False, forces immediate termination. Default is False.
679680
680681
Raises:
681-
NotImplementedError: If any orchestrator inheriting from the base
682-
class does not implement this logic.
683682
IllegalOperationError: If the run has no orchestrator run id yet.
684683
"""
685-
# Check if the orchestrator supports cancellation
686-
if (
687-
getattr(self._stop_run, "__func__", None)
688-
is BaseOrchestrator._stop_run
689-
):
690-
raise NotImplementedError(
691-
f"The '{self.__class__.__name__}' orchestrator does not "
692-
"support stopping pipeline runs."
693-
)
694684

695685
if not run.orchestrator_run_id:
696686
raise IllegalOperationError(
@@ -720,13 +710,84 @@ def _stop_run(
720710
run: A pipeline run response to stop (already updated to STOPPING status).
721711
graceful: If True, allows for graceful shutdown where possible.
722712
If False, forces immediate termination. Default is True.
713+
"""
714+
715+
if graceful:
716+
self._stop_run_gracefully(pipeline_run=run)
717+
else:
718+
self._stop_run_forcefully(pipeline_run=run)
719+
720+
@staticmethod
def _stop_run_gracefully(pipeline_run: "PipelineRunResponse") -> None:
    """Graceful pipeline shutdown.

    Iterates the running steps of the pipeline run and sets the status of
    every step that has heartbeat enabled to STOPPING. This in turn will
    cause heartbeat workers to interrupt execution of the step containers.

    Args:
        pipeline_run: A pipeline run response to stop (already updated to
            STOPPING status).

    Raises:
        RuntimeError: If any step failed to be set to STOPPING or was
            skipped because it does not have heartbeat enabled.
    """
    from zenml.client import Client
    from zenml.models import StepRunUpdate

    client = Client()

    steps_stopped: int = 0
    steps_failed: int = 0
    steps_skipped: int = 0

    # Collect the running steps *before* changing any status: the query
    # filters on RUNNING, so updating steps while still paging through the
    # results would shrink the result set and shift later pages, which can
    # cause steps to be missed.
    running_steps = list(
        depaginate_stream(
            client.list_run_steps,
            pipeline_run_id=pipeline_run.id,
            status=ExecutionStatus.RUNNING,
            size=50,
        )
    )

    for step in running_steps:
        if not step.spec.enable_heartbeat:
            # NOTE(review): a step without heartbeat presumably has no
            # worker that would react to the STOPPING status, so it is
            # skipped (and reported) instead of being updated - confirm.
            steps_skipped += 1
            continue

        try:
            client.zen_store.update_run_step(
                step_run_id=step.id,
                step_run_update=StepRunUpdate(
                    status=ExecutionStatus.STOPPING
                ),
            )
        except Exception as exc:
            # Best effort: keep going so the remaining steps still get
            # stopped. The failure is counted and reported below.
            logger.debug(
                "Could not update step %s status to %s: %s",
                step.id,
                ExecutionStatus.STOPPING,
                exc,
            )
            steps_failed += 1
        else:
            steps_stopped += 1

    operation_summary = (
        f"{steps_stopped=}, {steps_skipped=}, {steps_failed=}"
    )

    logger.info("Graceful stopping statistics: %s", operation_summary)

    if steps_failed or steps_skipped:
        if not steps_stopped:
            # If nothing was stopped successfully, raise an error
            raise RuntimeError(
                f"Failed to stop pipeline run: {operation_summary}"
            )
        else:
            # If some things were stopped but others failed, raise an error
            raise RuntimeError(
                f"Partial stop operation completed with errors: {operation_summary}"
            )
785+
786+
def _stop_run_forcefully(
787+
self, pipeline_run: "PipelineRunResponse"
788+
) -> None:
728789
raise NotImplementedError(
729-
"The stop run functionality is not implemented for the "
790+
"The forceful stop run functionality is not implemented for the "
730791
f"'{self.__class__.__name__}' orchestrator."
731792
)
732793

src/zenml/utils/pagination_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Pagination utilities."""
1515

16-
from typing import Any, Callable, List, TypeVar
16+
from typing import Any, Callable, Generator, List, TypeVar
1717

1818
from zenml.models import BaseIdentifiedResponse, Page
1919

@@ -40,3 +40,28 @@ def depaginate(
4040
items += list(page.items)
4141

4242
return items
43+
44+
45+
def depaginate_stream(
    list_method: Callable[..., Page[AnyResponse]], **kwargs: Any
) -> Generator[AnyResponse, None, None]:
    """Lazily depaginate the results of a method that returns pages.

    Unlike a list-building depagination, items are yielded one by one and
    the next page is only requested once the current page is exhausted.
    Nothing is fetched until the generator is first iterated.

    Args:
        list_method: The list method to depaginate.
        **kwargs: Arguments for the list method.

    Yields:
        The Response Models of every page, in page order.
    """
    # The first request uses the caller's arguments untouched.
    page = list_method(**kwargs)
    yield from page.items

    # Follow-up requests ask for the page after the one just received.
    while page.index < page.total_pages:
        kwargs["page"] = page.index + 1
        page = list_method(**kwargs)
        yield from page.items

0 commit comments

Comments
 (0)