|
52 | 52 | from zenml.stack import Flavor, Stack, StackComponent, StackComponentConfig |
53 | 53 | from zenml.steps.step_context import RunContext, get_or_create_run_context |
54 | 54 | from zenml.utils.env_utils import temporary_environment |
| 55 | +from zenml.utils.pagination_utils import depaginate_stream |
55 | 56 | from zenml.utils.pydantic_utils import before_validator_handler |
56 | 57 |
|
57 | 58 | if TYPE_CHECKING: |
@@ -678,19 +679,8 @@ def stop_run( |
678 | 679 | If False, forces immediate termination. Default is False. |
679 | 680 |
|
680 | 681 | Raises: |
681 | | - NotImplementedError: If any orchestrator inheriting from the base |
682 | | - class does not implement this logic. |
683 | 682 | IllegalOperationError: If the run has no orchestrator run id yet. |
684 | 683 | """ |
685 | | - # Check if the orchestrator supports cancellation |
686 | | - if ( |
687 | | - getattr(self._stop_run, "__func__", None) |
688 | | - is BaseOrchestrator._stop_run |
689 | | - ): |
690 | | - raise NotImplementedError( |
691 | | - f"The '{self.__class__.__name__}' orchestrator does not " |
692 | | - "support stopping pipeline runs." |
693 | | - ) |
694 | 684 |
|
695 | 685 | if not run.orchestrator_run_id: |
696 | 686 | raise IllegalOperationError( |
@@ -720,13 +710,84 @@ def _stop_run( |
720 | 710 | run: A pipeline run response to stop (already updated to STOPPING status). |
721 | 711 | graceful: If True, allows for graceful shutdown where possible. |
722 | 712 | If False, forces immediate termination. Default is True. |
| 713 | + """ |
| 714 | + |
| 715 | + if graceful: |
| 716 | + self._stop_run_gracefully(pipeline_run=run) |
| 717 | + else: |
| 718 | + self._stop_run_forcefully(pipeline_run=run) |
| 719 | + |
| 720 | + @staticmethod |
| 721 | + def _stop_run_gracefully(pipeline_run: "PipelineRunResponse") -> None: |
| 722 | + """Graceful pipeline shutdown. |
| 723 | +
|
| 724 | + Iterates running steps and sets their status to STOPPING. |
| 725 | + This is turn will cause heartbeat workers to interrupt execution of the step containers. |
| 726 | +
|
| 727 | + Args: |
| 728 | + pipeline_run: A pipeline run response to stop (already updated to STOPPING status). |
723 | 729 |
|
724 | 730 | Raises: |
725 | | - NotImplementedError: If any orchestrator inheriting from the base |
726 | | - class does not implement this logic. |
| 731 | + RuntimeError: If steps fail to be set to STOPPING or steps don't have heartbeat enabled. |
727 | 732 | """ |
| 733 | + |
| 734 | + from zenml.client import Client |
| 735 | + from zenml.models import StepRunUpdate |
| 736 | + |
| 737 | + client = Client() |
| 738 | + |
| 739 | + steps_stopped: int = 0 |
| 740 | + steps_failed: int = 0 |
| 741 | + steps_skipped: int = 0 |
| 742 | + |
| 743 | + for step in depaginate_stream( |
| 744 | + client.list_run_steps, |
| 745 | + pipeline_run_id=pipeline_run.id, |
| 746 | + status=ExecutionStatus.RUNNING, |
| 747 | + size=50, |
| 748 | + ): |
| 749 | + if not step.spec.enable_heartbeat: |
| 750 | + steps_skipped += 1 |
| 751 | + try: |
| 752 | + client.zen_store.update_run_step( |
| 753 | + step_run_id=step.id, |
| 754 | + step_run_update=StepRunUpdate( |
| 755 | + status=ExecutionStatus.STOPPING |
| 756 | + ), |
| 757 | + ) |
| 758 | + steps_stopped += 1 |
| 759 | + except Exception as exc: |
| 760 | + logger.debug( |
| 761 | + "Could update step % status to %: %s", |
| 762 | + step.id, |
| 763 | + ExecutionStatus.STOPPING, |
| 764 | + str(exc), |
| 765 | + ) |
| 766 | + steps_failed += 1 |
| 767 | + |
| 768 | + operation_summary = ( |
| 769 | + f"{steps_stopped}, {steps_skipped=}, {steps_failed=}" |
| 770 | + ) |
| 771 | + |
| 772 | + logger.info("Graceful stopping statistics: %s", operation_summary) |
| 773 | + |
| 774 | + if steps_failed or steps_skipped: |
| 775 | + if not steps_stopped: |
| 776 | + # If nothing was stopped successfully, raise an error |
| 777 | + raise RuntimeError( |
| 778 | + f"Failed to stop pipeline run: {operation_summary}" |
| 779 | + ) |
| 780 | + else: |
| 781 | + # If some things were stopped but others failed, raise an error |
| 782 | + raise RuntimeError( |
| 783 | + f"Partial stop operation completed with errors: {operation_summary}" |
| 784 | + ) |
| 785 | + |
| 786 | + def _stop_run_forcefully( |
| 787 | + self, pipeline_run: "PipelineRunResponse" |
| 788 | + ) -> None: |
728 | 789 | raise NotImplementedError( |
729 | | - "The stop run functionality is not implemented for the " |
| 790 | + "The forceful stop run functionality is not implemented for the " |
730 | 791 | f"'{self.__class__.__name__}' orchestrator." |
731 | 792 | ) |
732 | 793 |
|
|
0 commit comments