From 3b9d1f048cf6dd573a534fb4ab44631bd53a5799 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 25 Nov 2025 14:35:07 +0800 Subject: [PATCH 1/2] Add support for dynamic pipelines to the Vertex orchestrator --- .../steps-pipelines/dynamic_pipelines.md | 21 +- .../orchestrators/sagemaker_orchestrator.py | 4 +- .../gcp/orchestrators/vertex_orchestrator.py | 369 ++++++++++++++---- .../step_operators/vertex_step_operator.py | 177 ++------- src/zenml/integrations/gcp/utils.py | 215 ++++++++++ 5 files changed, 534 insertions(+), 252 deletions(-) create mode 100644 src/zenml/integrations/gcp/utils.py diff --git a/docs/book/how-to/steps-pipelines/dynamic_pipelines.md b/docs/book/how-to/steps-pipelines/dynamic_pipelines.md index cbb7da4e2ac..5a767c9d51a 100644 --- a/docs/book/how-to/steps-pipelines/dynamic_pipelines.md +++ b/docs/book/how-to/steps-pipelines/dynamic_pipelines.md @@ -5,7 +5,7 @@ description: Write dynamic pipelines # Dynamic Pipelines (Experimental) {% hint style="warning" %} -**Experimental Feature**: Dynamic pipelines are currently an experimental feature. There are known issues and limitations, and the interface is subject to change. This feature is only supported by the `local` and `kubernetes` orchestrators. If you encounter any issues or have feedback, please let us know at [https://github.com/zenml-io/zenml/issues](https://github.com/zenml-io/zenml/issues). +**Experimental Feature**: Dynamic pipelines are currently an experimental feature. There are known issues and limitations, and the interface is subject to change. This feature is only supported by the `local`, `kubernetes`, `sagemaker` and `vertex` orchestrators. If you encounter any issues or have feedback, please let us know at [https://github.com/zenml-io/zenml/issues](https://github.com/zenml-io/zenml/issues). {% endhint %} {% hint style="info" %} @@ -265,26 +265,11 @@ When running multiple steps concurrently using `step.submit()`, a failure in one Dynamic pipelines are currently only supported by: - `local` orchestrator - `kubernetes` orchestrator +- `sagemaker` orchestrator +- `vertex` orchestrator Other orchestrators will raise an error if you try to run a dynamic pipeline with them. -### Remote Execution Requirement - -When running dynamic pipelines remotely (e.g., with the `kubernetes` orchestrator), you **must** include `depends_on` for at least one step in your pipeline definition. This is currently required due to a bug in remote execution. - -{% hint style="warning" %} -**Required for Remote Execution**: Without `depends_on`, remote execution will fail. This requirement does not apply when running locally with the `local` orchestrator. -{% endhint %} - -For example: - -```python -@pipeline(dynamic=True, depends_on=[some_step]) -def dynamic_pipeline(): - some_step() - # ... rest of your pipeline -``` - ### Artifact Loading When you call `.load()` on an artifact in a dynamic pipeline, it synchronously loads the data. For large artifacts or when you want to maintain parallelism, consider passing the step outputs (future or artifact) directly to downstream steps instead of loading them. 
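For illustration, here is a minimal sketch of that pattern. The `train` and `report` steps are hypothetical, and the exact shape of the `submit` call is inferred from the description above rather than taken from the API reference:

```python
from zenml import pipeline, step


@step
def train(seed: int) -> float:
    return float(seed) / 100


@step
def report(score: float) -> None:
    print(f"score={score}")


@pipeline(dynamic=True)
def dynamic_pipeline() -> None:
    # Submit the step instead of calling it, so it runs concurrently.
    future = train.submit(seed=42)
    # Pass the future directly to the downstream step instead of calling
    # .load() on it, which would block here and pull the artifact into
    # memory, giving up the parallelism described above.
    report(future)
```

Loading is still available when you genuinely need the value in the pipeline body; the trade-off is only that `.load()` blocks until the producing step has finished.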
diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py index efc9ca1ebfc..5be8976bb5d 100644 --- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py @@ -1050,10 +1050,10 @@ def _wait_for_completion() -> None: metadata=metadata, ) - def launch_dynamic_step( + def run_isolated_step( self, step_run_info: "StepRunInfo", environment: Dict[str, str] ) -> None: - """Launch a dynamic step. + """Runs an isolated step on Sagemaker. Args: step_run_info: The step run information. diff --git a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py index 27c85080483..f3b60076fde 100644 --- a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +++ b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py @@ -41,6 +41,7 @@ Optional, Tuple, Type, + Union, cast, ) from uuid import UUID @@ -51,7 +52,7 @@ pipeline_service_client_v1beta1, ) from google.cloud.aiplatform.compat.types import pipeline_job_v1beta1 -from google.cloud.aiplatform_v1.types import PipelineState +from google.cloud.aiplatform_v1.types import JobState, PipelineState from google.cloud.aiplatform_v1beta1.types.service_networking import ( PscInterfaceConfig, ) @@ -63,17 +64,20 @@ from kfp.compiler import Compiler from kfp.dsl.base_component import BaseComponent +from zenml import __version__ from zenml.config.resource_settings import ResourceSettings from zenml.constants import ( METADATA_ORCHESTRATOR_LOGS_URL, METADATA_ORCHESTRATOR_RUN_ID, METADATA_ORCHESTRATOR_URL, + ORCHESTRATOR_DOCKER_IMAGE_KEY, ) from zenml.entrypoints import StepEntrypointConfiguration from zenml.enums import ExecutionStatus, StackComponentType from zenml.integrations.gcp import GCP_ARTIFACT_STORE_FLAVOR from zenml.integrations.gcp.constants import ( GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL, + VERTEX_ENDPOINT_SUFFIX, ) from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import ( VertexOrchestratorConfig, @@ -82,6 +86,10 @@ from zenml.integrations.gcp.google_credentials_mixin import ( GoogleCredentialsMixin, ) +from zenml.integrations.gcp.utils import ( + build_job_request, + monitor_job, +) from zenml.integrations.gcp.vertex_custom_job_parameters import ( VertexCustomJobParameters, ) @@ -90,11 +98,18 @@ from zenml.metadata.metadata_types import MetadataType, Uri from zenml.orchestrators import ContainerizedOrchestrator, SubmissionResult from zenml.orchestrators.utils import get_orchestrator_run_name +from zenml.pipelines.dynamic.entrypoint_configuration import ( + DynamicPipelineEntrypointConfiguration, +) from zenml.stack.stack_validator import StackValidator +from zenml.step_operators.step_operator_entrypoint_configuration import ( + StepOperatorEntrypointConfiguration, +) from zenml.utils.io_utils import get_global_config_directory if TYPE_CHECKING: from zenml.config.base_settings import BaseSettings + from zenml.config.step_run_info import StepRunInfo from zenml.models import ( PipelineRunResponse, PipelineSnapshotResponse, @@ -621,6 +636,155 @@ def dynamic_pipeline() -> None: schedule=snapshot.schedule, ) + def submit_dynamic_pipeline( + self, + snapshot: "PipelineSnapshotResponse", + stack: "Stack", + environment: Dict[str, str], + placeholder_run: Optional["PipelineRunResponse"] = None, + ) -> Optional[SubmissionResult]: + """Submits a dynamic pipeline to the orchestrator. 
+ + Args: + snapshot: The pipeline snapshot to submit. + stack: The stack the pipeline will run on. + environment: Environment variables to set in the orchestration + environment. + placeholder_run: An optional placeholder run. + + Raises: + RuntimeError: If the snapshot contains a schedule. + + Returns: + Optional submission result. + """ + if snapshot.schedule: + raise RuntimeError( + "Scheduling dynamic pipelines is not supported for the " + "Vertex orchestrator yet." + ) + + settings = cast( + VertexOrchestratorSettings, self.get_settings(snapshot) + ) + + command = ( + DynamicPipelineEntrypointConfiguration.get_entrypoint_command() + ) + args = DynamicPipelineEntrypointConfiguration.get_entrypoint_arguments( + snapshot_id=snapshot.id, + run_id=placeholder_run.id if placeholder_run else None, + ) + + image = self.get_image(snapshot=snapshot) + labels = settings.labels.copy() + labels["source"] = f"zenml-{__version__.replace('.', '_')}" + + job_request = build_job_request( + display_name=get_orchestrator_run_name( + pipeline_name=snapshot.pipeline_configuration.name + ), + image=image, + entrypoint_command=command + args, + custom_job_settings=settings.custom_job_parameters + or VertexCustomJobParameters(), + resource_settings=snapshot.pipeline_configuration.resource_settings, + environment=environment, + labels=labels, + encryption_spec_key_name=self.config.encryption_spec_key_name, + service_account=self.config.workload_service_account, + network=self.config.network, + ) + + credentials, project_id = self._get_authentication() + client_options = { + "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX + } + client = aiplatform.gapic.JobServiceClient( + credentials=credentials, client_options=client_options + ) + parent = f"projects/{project_id}/locations/{self.config.location}" + job_model = client.create_custom_job( + parent=parent, custom_job=job_request + ) + + wait_for_completion = None + if settings.synchronous: + wait_for_completion = lambda: monitor_job( + job_id=job_model.name, + credentials_source=self, + client_options=client_options, + ) + + self._initialize_vertex_client() + job = aiplatform.CustomJob.get(job_model.name) + metadata = self.compute_metadata(job) + + logger.info("View the Vertex job at %s", job._dashboard_uri()) + + return SubmissionResult( + wait_for_completion=wait_for_completion, + metadata=metadata, + ) + + def run_isolated_step( + self, step_run_info: "StepRunInfo", environment: Dict[str, str] + ) -> None: + """Runs an isolated step on Vertex. + + Args: + step_run_info: The step run information. + environment: The environment variables to set. 
+ """ + settings = cast( + VertexOrchestratorSettings, self.get_settings(step_run_info) + ) + + image = step_run_info.get_image(key=ORCHESTRATOR_DOCKER_IMAGE_KEY) + command = StepOperatorEntrypointConfiguration.get_entrypoint_command() + args = StepOperatorEntrypointConfiguration.get_entrypoint_arguments( + step_name=step_run_info.pipeline_step_name, + snapshot_id=(step_run_info.snapshot.id), + step_run_id=str(step_run_info.step_run_id), + ) + + labels = settings.labels.copy() + labels["source"] = f"zenml-{__version__.replace('.', '_')}" + + job_request = build_job_request( + display_name=f"{step_run_info.run_name}-{step_run_info.pipeline_step_name}", + image=image, + entrypoint_command=command + args, + custom_job_settings=settings.custom_job_parameters + or VertexCustomJobParameters(), + resource_settings=step_run_info.config.resource_settings, + environment=environment, + labels=labels, + encryption_spec_key_name=self.config.encryption_spec_key_name, + service_account=self.config.workload_service_account, + network=self.config.network, + ) + + credentials, project_id = self._get_authentication() + client_options = { + "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX + } + client = aiplatform.gapic.JobServiceClient( + credentials=credentials, client_options=client_options + ) + parent = f"projects/{project_id}/locations/{self.config.location}" + logger.info( + "Submitting custom job='%s', path='%s' to Vertex AI Training.", + job_request["display_name"], + parent, + ) + job = client.create_custom_job(parent=parent, custom_job=job_request) + monitor_job( + job_id=job.name, + credentials_source=self, + client_options=client_options, + ) + def _upload_and_run_pipeline( self, pipeline_name: str, @@ -786,19 +950,19 @@ def get_orchestrator_run_id(self) -> str: """Returns the active orchestrator run id. Raises: - RuntimeError: If the environment variable specifying the run id - is not set. + RuntimeError: If the orchestrator run id cannot be read from the + environment. Returns: The orchestrator run id. """ - try: - return os.environ[ENV_ZENML_VERTEX_RUN_ID] - except KeyError: - raise RuntimeError( - "Unable to read run id from environment variable " - f"{ENV_ZENML_VERTEX_RUN_ID}." - ) + for env in [ENV_ZENML_VERTEX_RUN_ID, "CLOUD_ML_JOB_ID"]: + if env in os.environ: + return os.environ[env] + + raise RuntimeError( + "Unable to get orchestrator run id from environment." + ) def get_pipeline_run_metadata( self, run_id: UUID @@ -811,14 +975,26 @@ def get_pipeline_run_metadata( Returns: A dictionary of metadata. 
""" - run_url = ( - f"https://console.cloud.google.com/vertex-ai/locations/" - f"{self.config.location}/pipelines/runs/" - f"{self.get_orchestrator_run_id()}" - ) + if ENV_ZENML_VERTEX_RUN_ID in os.environ: + # Static pipeline -> Pipeline job + run_url = ( + f"https://console.cloud.google.com/vertex-ai/locations/" + f"{self.config.location}/pipelines/runs/" + f"{self.get_orchestrator_run_id()}" + ) + else: + # Dynamic pipeline -> Custom job + run_url = ( + f"https://console.cloud.google.com/vertex-ai/locations/" + f"{self.config.location}/training/" + f"{self.get_orchestrator_run_id()}" + ) + if self.config.project: run_url += f"?project={self.config.project}" + return { + METADATA_ORCHESTRATOR_RUN_ID: self.get_orchestrator_run_id(), METADATA_ORCHESTRATOR_URL: Uri(run_url), } @@ -884,6 +1060,15 @@ def _configure_container_resources( return dynamic_component + def _initialize_vertex_client(self) -> None: + """Initializes the Vertex client.""" + credentials, project_id = self._get_authentication() + aiplatform.init( + project=project_id, + location=self.config.location, + credentials=credentials, + ) + def fetch_status( self, run: "PipelineRunResponse", include_steps: bool = False ) -> Tuple[ @@ -917,13 +1102,7 @@ def fetch_status( == run.stack.components[StackComponentType.ORCHESTRATOR][0].id ) - # Initialize the Vertex client - credentials, project_id = self._get_authentication() - aiplatform.init( - project=project_id, - location=self.config.location, - credentials=credentials, - ) + self._initialize_vertex_client() # Fetch the status of the PipelineJob if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata: @@ -935,58 +1114,84 @@ def fetch_status( "Can not find the orchestrator run ID, thus can not fetch " "the status." ) - status = aiplatform.PipelineJob.get(run_id).state - - # Map the potential outputs to ZenML ExecutionStatus. 
Potential values:
-        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/describe_pipeline_execution.html#
-        if status == PipelineState.PIPELINE_STATE_UNSPECIFIED:
-            pipeline_status = run.status
-        elif status in [
-            PipelineState.PIPELINE_STATE_QUEUED,
-            PipelineState.PIPELINE_STATE_PENDING,
-        ]:
-            pipeline_status = ExecutionStatus.INITIALIZING
-        elif status in [
-            PipelineState.PIPELINE_STATE_RUNNING,
-            PipelineState.PIPELINE_STATE_PAUSED,
-        ]:
-            pipeline_status = ExecutionStatus.RUNNING
-        elif status == PipelineState.PIPELINE_STATE_SUCCEEDED:
-            pipeline_status = ExecutionStatus.COMPLETED
-        elif status == PipelineState.PIPELINE_STATE_CANCELLING:
-            pipeline_status = ExecutionStatus.STOPPING
-        elif status == PipelineState.PIPELINE_STATE_CANCELLED:
-            pipeline_status = ExecutionStatus.STOPPED
-        elif status == PipelineState.PIPELINE_STATE_FAILED:
-            pipeline_status = ExecutionStatus.FAILED
+
+        if run.snapshot and run.snapshot.is_dynamic:
+            status = aiplatform.CustomJob.get(run_id).state
+
+            if status in [
+                JobState.JOB_STATE_QUEUED,
+                JobState.JOB_STATE_PENDING,
+            ]:
+                pipeline_status = ExecutionStatus.PROVISIONING
+            elif status in [
+                JobState.JOB_STATE_RUNNING,
+                JobState.JOB_STATE_PAUSED,
+                JobState.JOB_STATE_UPDATING,
+            ]:
+                pipeline_status = ExecutionStatus.RUNNING
+            elif status == JobState.JOB_STATE_SUCCEEDED:
+                pipeline_status = ExecutionStatus.COMPLETED
+            elif status == JobState.JOB_STATE_CANCELLING:
+                pipeline_status = ExecutionStatus.STOPPING
+            elif status == JobState.JOB_STATE_CANCELLED:
+                pipeline_status = ExecutionStatus.STOPPED
+            elif status in [
+                JobState.JOB_STATE_FAILED,
+                JobState.JOB_STATE_EXPIRED,
+            ]:
+                pipeline_status = ExecutionStatus.FAILED
+            else:
+                pipeline_status = run.status
         else:
-            raise ValueError("Unknown status for the pipeline job.")
+            status = aiplatform.PipelineJob.get(run_id).state
+
+            # Map the Vertex AI PipelineState values to ZenML ExecutionStatus.
+            # Potential values are defined in google.cloud.aiplatform_v1.types.PipelineState.
+            if status == PipelineState.PIPELINE_STATE_UNSPECIFIED:
+                pipeline_status = run.status
+            elif status in [
+                PipelineState.PIPELINE_STATE_QUEUED,
+                PipelineState.PIPELINE_STATE_PENDING,
+            ]:
+                pipeline_status = ExecutionStatus.PROVISIONING
+            elif status in [
+                PipelineState.PIPELINE_STATE_RUNNING,
+                PipelineState.PIPELINE_STATE_PAUSED,
+            ]:
+                pipeline_status = ExecutionStatus.RUNNING
+            elif status == PipelineState.PIPELINE_STATE_SUCCEEDED:
+                pipeline_status = ExecutionStatus.COMPLETED
+            elif status == PipelineState.PIPELINE_STATE_CANCELLING:
+                pipeline_status = ExecutionStatus.STOPPING
+            elif status == PipelineState.PIPELINE_STATE_CANCELLED:
+                pipeline_status = ExecutionStatus.STOPPED
+            elif status == PipelineState.PIPELINE_STATE_FAILED:
+                pipeline_status = ExecutionStatus.FAILED
+            else:
+                raise ValueError("Unknown status for the pipeline job.")
 
         # Vertex doesn't support step-level status fetching yet
         return pipeline_status, None
 
     def compute_metadata(
-        self, job: aiplatform.PipelineJob
+        self, job: Union[aiplatform.PipelineJob, aiplatform.CustomJob]
     ) -> Dict[str, MetadataType]:
-        """Generate run metadata based on the corresponding Vertex PipelineJob.
+        """Generate run metadata based on the Vertex job.
 
         Args:
-            job: The corresponding PipelineJob object.
+            job: The job.
 
         Returns:
             A dictionary of metadata related to the pipeline run. 
""" metadata: Dict[str, MetadataType] = {} - # Orchestrator Run ID if run_id := self._compute_orchestrator_run_id(job): metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id - # URL to the Vertex's pipeline view if orchestrator_url := self._compute_orchestrator_url(job): metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url) - # URL to the corresponding Logs Explorer page if logs_url := self._compute_orchestrator_logs_url(job): metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url) @@ -994,70 +1199,76 @@ def compute_metadata( @staticmethod def _compute_orchestrator_url( - job: aiplatform.PipelineJob, + job: Union[aiplatform.PipelineJob, aiplatform.CustomJob], ) -> Optional[str]: - """Generate the Orchestrator Dashboard URL upon pipeline execution. + """Generate the Orchestrator Dashboard URL. Args: - job: The corresponding PipelineJob object. + job: The job. Returns: - the URL to the dashboard view in Vertex. + The Vertex Dashboard URL for the job. """ try: - return str(job._dashboard_uri()) + if uri := job._dashboard_uri(): + return Uri(uri) except Exception as e: logger.warning( - f"There was an issue while extracting the pipeline url: {e}" + "There was an issue while extracting the job dashboard URL: %s", + e, ) - return None + return None @staticmethod def _compute_orchestrator_logs_url( - job: aiplatform.PipelineJob, + job: Union[aiplatform.PipelineJob, aiplatform.CustomJob], ) -> Optional[str]: - """Generate the Logs Explorer URL upon pipeline execution. + """Generate the Logs Explorer URL. Args: - job: The corresponding PipelineJob object. + job: The job. Returns: - the URL querying the pipeline logs in Logs Explorer on GCP. + The Logs Explorer URL for the job. """ try: - base_url = "https://console.cloud.google.com/logs/query" - query = f""" - resource.type="aiplatform.googleapis.com/PipelineJob" - resource.labels.pipeline_job_id="{job.job_id}" - """ - encoded_query = urllib.parse.quote(query) - return f"{base_url}?project={job.project}&query={encoded_query}" + if isinstance(job, aiplatform.PipelineJob): + query = f""" + resource.type="aiplatform.googleapis.com/PipelineJob" + resource.labels.pipeline_job_id="{job.job_id}" + """ + else: + query = f'resource.labels.job_id="{job.name}"' + query = urllib.parse.quote(query) except Exception as e: logger.warning( f"There was an issue while extracting the logs url: {e}" ) return None + else: + return f"https://console.cloud.google.com/logs/query?project={job.project}&query={query}" @staticmethod def _compute_orchestrator_run_id( - job: aiplatform.PipelineJob, + job: Union[aiplatform.PipelineJob, aiplatform.CustomJob], ) -> Optional[str]: - """Fetch the Orchestrator Run ID upon pipeline execution. + """Fetch the orchestrator run ID. Args: - job: The corresponding PipelineJob object. + job: The job. Returns: - the Execution ID of the run in Vertex. + The orchestrator run ID. 
""" try: - if job.job_id: + if isinstance(job, aiplatform.PipelineJob): return str(job.job_id) - - return None + else: + return str(job.name) except Exception as e: logger.warning( - f"There was an issue while extracting the pipeline run ID: {e}" + "There was an issue while extracting the orchestrator run ID: %s", + e, ) return None diff --git a/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py b/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py index 427b4b8c7b0..3bf1f53c9ae 100644 --- a/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py +++ b/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py @@ -18,21 +18,15 @@ google_cloud_ai_platform/training_clients.py """ -import time from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast -from google.api_core.exceptions import ServerError from google.cloud import aiplatform from zenml import __version__ from zenml.config.build_configuration import BuildConfiguration from zenml.enums import StackComponentType from zenml.integrations.gcp.constants import ( - CONNECTION_ERROR_RETRY_LIMIT, - POLLING_INTERVAL_IN_SECONDS, VERTEX_ENDPOINT_SUFFIX, - VERTEX_JOB_STATES_COMPLETED, - VERTEX_JOB_STATES_FAILED, ) from zenml.integrations.gcp.flavors.vertex_step_operator_flavor import ( VertexStepOperatorConfig, @@ -41,6 +35,7 @@ from zenml.integrations.gcp.google_credentials_mixin import ( GoogleCredentialsMixin, ) +from zenml.integrations.gcp.utils import build_job_request, monitor_job from zenml.logger import get_logger from zenml.stack import Stack, StackValidator from zenml.step_operators import BaseStepOperator @@ -55,22 +50,6 @@ VERTEX_DOCKER_IMAGE_KEY = "vertex_step_operator" -def validate_accelerator_type(accelerator_type: Optional[str] = None) -> None: - """Validates that the accelerator type is valid. - - Args: - accelerator_type: The accelerator type to validate. - - Raises: - ValueError: If the accelerator type is not valid. - """ - accepted_vals = list(aiplatform.gapic.AcceleratorType.__members__.keys()) - if accelerator_type and accelerator_type.upper() not in accepted_vals: - raise ValueError( - f"Accelerator must be one of the following: {accepted_vals}" - ) - - class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin): """Step operator to run a step on Vertex AI. @@ -184,9 +163,6 @@ def launch( entrypoint_command: Command that executes the step. environment: Environment variables to set in the step operator environment. - - Raises: - RuntimeError: If the run fails. 
""" resource_settings = info.config.resource_settings if resource_settings.cpu_count or resource_settings.memory: @@ -200,150 +176,45 @@ def launch( self.name, ) settings = cast(VertexStepOperatorSettings, self.get_settings(info)) - validate_accelerator_type(settings.accelerator_type) - - job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"} + image = info.get_image(key=VERTEX_DOCKER_IMAGE_KEY) + + labels = {"source": f"zenml-{__version__.replace('.', '_')}"} + job_request = build_job_request( + display_name=f"{info.run_name}-{info.pipeline_step_name}", + image=image, + entrypoint_command=entrypoint_command, + custom_job_settings=settings, + resource_settings=info.config.resource_settings, + environment=environment, + labels=labels, + encryption_spec_key_name=self.config.encryption_spec_key_name, + service_account=self.config.service_account, + network=self.config.network, + ) + logger.debug("Vertex AI Job=%s", job_request) - # Step 1: Authenticate with Google credentials, project_id = self._get_authentication() - - image_name = info.get_image(key=VERTEX_DOCKER_IMAGE_KEY) - - # Step 3: Launch the job - # The AI Platform services require regional API endpoints. client_options = { "api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX } - # Initialize client that will be used to create and send requests. - # This client only needs to be created once, and can be reused for - # multiple requests. client = aiplatform.gapic.JobServiceClient( credentials=credentials, client_options=client_options ) - accelerator_count = ( - resource_settings.gpu_count or settings.accelerator_count - ) - custom_job = { - "display_name": info.run_name, - "job_spec": { - "worker_pool_specs": [ - { - "machine_spec": { - "machine_type": settings.machine_type, - "accelerator_type": settings.accelerator_type, - "accelerator_count": accelerator_count - if settings.accelerator_type - else 0, - }, - "replica_count": 1, - "container_spec": { - "image_uri": image_name, - "command": entrypoint_command, - "args": [], - "env": [ - {"name": key, "value": value} - for key, value in environment.items() - ], - }, - "disk_spec": { - "boot_disk_type": settings.boot_disk_type, - "boot_disk_size_gb": settings.boot_disk_size_gb, - }, - } - ], - "service_account": self.config.service_account, - "network": self.config.network, - "reserved_ip_ranges": ( - self.config.reserved_ip_ranges.split(",") - if self.config.reserved_ip_ranges - else [] - ), - "persistent_resource_id": settings.persistent_resource_id, - }, - "labels": job_labels, - "encryption_spec": { - "kmsKeyName": self.config.encryption_spec_key_name - } - if self.config.encryption_spec_key_name - else {}, - } - logger.debug("Vertex AI Job=%s", custom_job) parent = f"projects/{project_id}/locations/{self.config.region}" logger.info( "Submitting custom job='%s', path='%s' to Vertex AI Training.", - custom_job["display_name"], + job_request["display_name"], parent, ) info.force_write_logs() response = client.create_custom_job( - parent=parent, custom_job=custom_job + parent=parent, custom_job=job_request ) logger.debug("Vertex AI response:", response) - # Step 4: Monitor the job - - # Monitors the long-running operation by polling the job state - # periodically, and retries the polling when a transient connectivity - # issue is encountered. 
- # - # Long-running operation monitoring: - # The possible states of "get job" response can be found at - # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State - # where SUCCEEDED/FAILED/CANCELED are considered to be final states. - # The following logic will keep polling the state of the job until - # the job enters a final state. - # - # During the polling, if a connection error was encountered, the GET - # request will be retried by recreating the Python API client to - # refresh the lifecycle of the connection being used. See - # https://github.com/googleapis/google-api-python-client/issues/218 - # for a detailed description of the problem. If the error persists for - # _CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function - # will raise ConnectionError. - retry_count = 0 - job_id = response.name - - while response.state not in VERTEX_JOB_STATES_COMPLETED: - time.sleep(POLLING_INTERVAL_IN_SECONDS) - if self.connector_has_expired(): - logger.warning("Connector has expired. Recreating client...") - # This call will refresh the credentials if they expired. - credentials, project_id = self._get_authentication() - # Recreate the Python API client. - client = aiplatform.gapic.JobServiceClient( - credentials=credentials, client_options=client_options - ) - try: - response = client.get_custom_job(name=job_id) - retry_count = 0 - # Handle transient connection errors and credential expiration by - # recreating the Python API client. - except (ConnectionError, ServerError) as err: - if retry_count < CONNECTION_ERROR_RETRY_LIMIT: - retry_count += 1 - logger.warning( - f"Error encountered when polling job " - f"{job_id}: {err}\nRetrying...", - ) - continue - else: - logger.exception( - "Request failed after %s retries.", - CONNECTION_ERROR_RETRY_LIMIT, - ) - raise RuntimeError( - f"Request failed after {CONNECTION_ERROR_RETRY_LIMIT} " - f"retries: {err}" - ) - if response.state in VERTEX_JOB_STATES_FAILED: - err_msg = ( - "Job '{}' did not succeed. Detailed response {}.".format( - job_id, response - ) - ) - logger.error(err_msg) - raise RuntimeError(err_msg) - - # Cloud training complete - logger.info("Job '%s' successful.", job_id) + monitor_job( + job_id=response.name, + credentials_source=self, + client_options=client_options, + ) diff --git a/src/zenml/integrations/gcp/utils.py b/src/zenml/integrations/gcp/utils.py new file mode 100644 index 00000000000..5ab0042c87b --- /dev/null +++ b/src/zenml/integrations/gcp/utils.py @@ -0,0 +1,215 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Vertex utilities.""" + +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from google.api_core.exceptions import ServerError +from google.cloud import aiplatform + +from zenml.integrations.gcp.constants import ( + CONNECTION_ERROR_RETRY_LIMIT, + POLLING_INTERVAL_IN_SECONDS, + VERTEX_JOB_STATES_COMPLETED, + VERTEX_JOB_STATES_FAILED, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsMixin, +) +from zenml.integrations.gcp.vertex_custom_job_parameters import ( + VertexCustomJobParameters, +) +from zenml.logger import get_logger + +if TYPE_CHECKING: + from zenml.config.resource_settings import ResourceSettings + +logger = get_logger(__name__) + + +def validate_accelerator_type(accelerator_type: Optional[str] = None) -> None: + """Validates that the accelerator type is valid. + + Args: + accelerator_type: The accelerator type to validate. + + Raises: + ValueError: If the accelerator type is not valid. + """ + accepted_vals = list(aiplatform.gapic.AcceleratorType.__members__.keys()) + if accelerator_type and accelerator_type.upper() not in accepted_vals: + raise ValueError( + f"Accelerator must be one of the following: {accepted_vals}" + ) + + +def get_job_service_client( + credentials_source: GoogleCredentialsMixin, + client_options: Optional[Dict[str, Any]] = None, +) -> aiplatform.gapic.JobServiceClient: + """Gets a job service client. + + Args: + credentials_source: The component that provides the credentials to + access the job. + client_options: The client options to use for the job service client. + + Returns: + A job service client. + """ + credentials, _ = credentials_source._get_authentication() + return aiplatform.gapic.JobServiceClient( + credentials=credentials, client_options=client_options + ) + + +def monitor_job( + job_id: str, + credentials_source: GoogleCredentialsMixin, + client_options: Optional[Dict[str, Any]] = None, +) -> None: + """Monitors a job until it is completed. + + Args: + job_id: The ID of the job to monitor. + credentials_source: The component that provides the credentials to + access the job. + client_options: The client options to use for the job service client. + + Raises: + RuntimeError: If the job fails. + """ + retry_count = 0 + client = get_job_service_client( + credentials_source=credentials_source, client_options=client_options + ) + + while True: + time.sleep(POLLING_INTERVAL_IN_SECONDS) + if credentials_source.connector_has_expired(): + client = get_job_service_client( + credentials_source=credentials_source, + client_options=client_options, + ) + + try: + response = client.get_custom_job(name=job_id) + retry_count = 0 + except (ConnectionError, ServerError) as err: + if retry_count < CONNECTION_ERROR_RETRY_LIMIT: + retry_count += 1 + logger.warning( + f"Error encountered when polling job " + f"{job_id}: {err}\nRetrying...", + ) + continue + else: + logger.exception( + "Request failed after %s retries.", + CONNECTION_ERROR_RETRY_LIMIT, + ) + raise RuntimeError( + f"Request failed after {CONNECTION_ERROR_RETRY_LIMIT} " + f"retries: {err}" + ) + if response.state in VERTEX_JOB_STATES_FAILED: + err_msg = f"Job `{job_id}` failed: {response}." 
+ logger.error(err_msg) + raise RuntimeError(err_msg) + if response.state in VERTEX_JOB_STATES_COMPLETED: + break + + logger.info("Job `%s` successful.", job_id) + + +def build_job_request( + display_name: str, + image: str, + entrypoint_command: List[str], + custom_job_settings: VertexCustomJobParameters, + resource_settings: "ResourceSettings", + environment: Optional[Dict[str, str]] = None, + labels: Optional[Dict[str, str]] = None, + encryption_spec_key_name: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + reserved_ip_ranges: Optional[str] = None, +) -> Dict[str, Any]: + """Build a job request. + + Args: + display_name: The display name of the job. + image: The image URI of the job. + entrypoint_command: The entrypoint command of the job. + custom_job_settings: The custom job settings. + resource_settings: The resource settings. + environment: The environment variables. + labels: The labels. + encryption_spec_key_name: The encryption spec key name. + service_account: The service account. + network: The network. + reserved_ip_ranges: The reserved IP ranges. + + Returns: + Job request dictionary. + """ + environment = environment or {} + labels = labels or {} + + validate_accelerator_type(custom_job_settings.accelerator_type) + + accelerator_count = ( + resource_settings.gpu_count or custom_job_settings.accelerator_count + ) + return { + "display_name": display_name, + "job_spec": { + "worker_pool_specs": [ + { + "machine_spec": { + "machine_type": custom_job_settings.machine_type, + "accelerator_type": custom_job_settings.accelerator_type, + "accelerator_count": accelerator_count + if custom_job_settings.accelerator_type + else 0, + }, + "replica_count": 1, + "container_spec": { + "image_uri": image, + "command": entrypoint_command, + "args": [], + "env": [ + {"name": key, "value": value} + for key, value in environment.items() + ], + }, + "disk_spec": { + "boot_disk_type": custom_job_settings.boot_disk_type, + "boot_disk_size_gb": custom_job_settings.boot_disk_size_gb, + }, + } + ], + "service_account": service_account, + "network": network, + "reserved_ip_ranges": ( + reserved_ip_ranges.split(",") if reserved_ip_ranges else [] + ), + "persistent_resource_id": custom_job_settings.persistent_resource_id, + }, + "labels": labels, + "encryption_spec": {"kmsKeyName": encryption_spec_key_name} + if encryption_spec_key_name + else {}, + } From 1e365337f47c115c39875ed78d293426254b9e7f Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 27 Nov 2025 19:41:46 +0800 Subject: [PATCH 2/2] Address review comments --- .../gcp/google_credentials_mixin.py | 25 +++++ .../gcp/orchestrators/vertex_orchestrator.py | 95 +++++++++++-------- .../step_operators/vertex_step_operator.py | 35 +++++-- src/zenml/integrations/gcp/utils.py | 45 ++------- 4 files changed, 115 insertions(+), 85 deletions(-) diff --git a/src/zenml/integrations/gcp/google_credentials_mixin.py b/src/zenml/integrations/gcp/google_credentials_mixin.py index 681bbf4da66..2b3065b650e 100644 --- a/src/zenml/integrations/gcp/google_credentials_mixin.py +++ b/src/zenml/integrations/gcp/google_credentials_mixin.py @@ -48,6 +48,9 @@ class GoogleCredentialsConfigMixin(StackComponentConfig): class GoogleCredentialsMixin(StackComponent): """StackComponent mixin to get Google Cloud Platform credentials.""" + _gcp_credentials: Optional["Credentials"] = None + _gcp_project_id: Optional[str] = None + @property def config(self) -> GoogleCredentialsConfigMixin: """Returns the 
`GoogleCredentialsConfigMixin` config. @@ -57,6 +60,18 @@ def config(self) -> GoogleCredentialsConfigMixin: """ return cast(GoogleCredentialsConfigMixin, self._config) + @property + def gcp_project_id(self) -> str: + """Get the GCP project ID. + + Returns: + The GCP project ID. + """ + if self._gcp_project_id is None: + _, self._gcp_project_id = self._get_authentication() + + return self._gcp_project_id + def _get_authentication(self) -> Tuple["Credentials", str]: """Get GCP credentials and the project ID associated with the credentials. @@ -79,6 +94,12 @@ def _get_authentication(self) -> Tuple["Credentials", str]: GCPServiceConnector, ) + if self.connector_has_expired(): + self._gcp_credentials = None + + if self._gcp_credentials and self._gcp_project_id: + return self._gcp_credentials, self._gcp_project_id + connector = self.get_connector() if connector: credentials = connector.connect() @@ -90,6 +111,8 @@ def _get_authentication(self) -> Tuple["Credentials", str]: "trying to use the linked connector, but got " f"{type(credentials)}." ) + self._gcp_credentials = credentials + self._gcp_project_id = connector.config.gcp_project_id return credentials, connector.config.gcp_project_id if self.config.service_account_path: @@ -111,4 +134,6 @@ def _get_authentication(self) -> Tuple["Credentials", str]: # If the project was set in the configuration, use it. Otherwise, use # the project that was used to authenticate. project_id = self.config.project if self.config.project else project_id + self._gcp_credentials = credentials + self._gcp_project_id = project_id return credentials, project_id diff --git a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py index f3b60076fde..510f4e7acc1 100644 --- a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +++ b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py @@ -144,6 +144,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin): """Orchestrator responsible for running pipelines on Vertex AI.""" _pipeline_root: str + _job_service_client: Optional[aiplatform.gapic.JobServiceClient] = None @property def config(self) -> VertexOrchestratorConfig: @@ -261,6 +262,25 @@ def pipeline_directory(self) -> str: """ return os.path.join(self.root_directory, "pipelines") + def get_job_service_client(self) -> aiplatform.gapic.JobServiceClient: + """Get the job service client. + + Returns: + The job service client. 
+ """ + if self.connector_has_expired(): + self._job_service_client = None + + if self._job_service_client is None: + credentials, _ = self._get_authentication() + client_options = { + "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX + } + self._job_service_client = aiplatform.gapic.JobServiceClient( + credentials=credentials, client_options=client_options + ) + return self._job_service_client + def _create_container_component( self, image: str, @@ -696,34 +716,38 @@ def submit_dynamic_pipeline( network=self.config.network, ) - credentials, project_id = self._get_authentication() - client_options = { - "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX - } - client = aiplatform.gapic.JobServiceClient( - credentials=credentials, client_options=client_options + client = self.get_job_service_client() + parent = ( + f"projects/{self.gcp_project_id}/locations/{self.config.location}" ) - parent = f"projects/{project_id}/locations/{self.config.location}" job_model = client.create_custom_job( parent=parent, custom_job=job_request ) - wait_for_completion = None + _wait_for_completion = None if settings.synchronous: - wait_for_completion = lambda: monitor_job( - job_id=job_model.name, - credentials_source=self, - client_options=client_options, - ) - self._initialize_vertex_client() - job = aiplatform.CustomJob.get(job_model.name) + def _wait_for_completion() -> None: + logger.info("Waiting for the VertexAI job to finish...") + monitor_job( + job_id=job_model.name, + get_client=self.get_job_service_client, + ) + logger.info("VertexAI job completed successfully.") + + credentials, project_id = self._get_authentication() + job = aiplatform.CustomJob.get( + job_model.name, + project=project_id, + location=self.config.location, + credentials=credentials, + ) metadata = self.compute_metadata(job) logger.info("View the Vertex job at %s", job._dashboard_uri()) return SubmissionResult( - wait_for_completion=wait_for_completion, + wait_for_completion=_wait_for_completion, metadata=metadata, ) @@ -765,14 +789,10 @@ def run_isolated_step( network=self.config.network, ) - credentials, project_id = self._get_authentication() - client_options = { - "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX - } - client = aiplatform.gapic.JobServiceClient( - credentials=credentials, client_options=client_options + client = self.get_job_service_client() + parent = ( + f"projects/{self.gcp_project_id}/locations/{self.config.location}" ) - parent = f"projects/{project_id}/locations/{self.config.location}" logger.info( "Submitting custom job='%s', path='%s' to Vertex AI Training.", job_request["display_name"], @@ -781,8 +801,7 @@ def run_isolated_step( job = client.create_custom_job(parent=parent, custom_job=job_request) monitor_job( job_id=job.name, - credentials_source=self, - client_options=client_options, + get_client=self.get_job_service_client, ) def _upload_and_run_pipeline( @@ -1060,15 +1079,6 @@ def _configure_container_resources( return dynamic_component - def _initialize_vertex_client(self) -> None: - """Initializes the Vertex client.""" - credentials, project_id = self._get_authentication() - aiplatform.init( - project=project_id, - location=self.config.location, - credentials=credentials, - ) - def fetch_status( self, run: "PipelineRunResponse", include_steps: bool = False ) -> Tuple[ @@ -1102,8 +1112,6 @@ def fetch_status( == run.stack.components[StackComponentType.ORCHESTRATOR][0].id ) - self._initialize_vertex_client() - # Fetch the status of the PipelineJob if 
METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
             run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
@@ -1115,8 +1123,14 @@ def fetch_status(
                 "the status."
             )
 
+        credentials, project_id = self._get_authentication()
         if run.snapshot and run.snapshot.is_dynamic:
-            status = aiplatform.CustomJob.get(run_id).state
+            status = aiplatform.CustomJob.get(
+                run_id,
+                project=project_id,
+                location=self.config.location,
+                credentials=credentials,
+            ).state
 
             if status in [
                 JobState.JOB_STATE_QUEUED,
@@ -1143,7 +1157,12 @@ def fetch_status(
         else:
-            status = aiplatform.PipelineJob.get(run_id).state
+            status = aiplatform.PipelineJob.get(
+                run_id,
+                project=project_id,
+                location=self.config.location,
+                credentials=credentials,
+            ).state
 
             # Map the Vertex AI PipelineState values to ZenML ExecutionStatus.
             # Potential values are defined in google.cloud.aiplatform_v1.types.PipelineState.
diff --git a/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py b/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py
index 3bf1f53c9ae..76cc07cd7f2 100644
--- a/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py
+++ b/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py
@@ -57,6 +57,8 @@ class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
     ZenML entrypoint command in it.
     """
 
+    _job_service_client: Optional[aiplatform.gapic.JobServiceClient] = None
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """Initializes the step operator and validates the accelerator type.
 
@@ -150,6 +152,25 @@ def get_docker_builds(
 
         return builds
 
+    def get_job_service_client(self) -> aiplatform.gapic.JobServiceClient:
+        """Get the job service client.
+
+        Returns:
+            The job service client. 
+ """ + if self.connector_has_expired(): + self._job_service_client = None + + if self._job_service_client is None: + credentials, _ = self._get_authentication() + client_options = { + "api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX + } + self._job_service_client = aiplatform.gapic.JobServiceClient( + credentials=credentials, client_options=client_options + ) + return self._job_service_client + def launch( self, info: "StepRunInfo", @@ -193,15 +214,10 @@ def launch( ) logger.debug("Vertex AI Job=%s", job_request) - credentials, project_id = self._get_authentication() - client_options = { - "api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX - } - client = aiplatform.gapic.JobServiceClient( - credentials=credentials, client_options=client_options + client = self.get_job_service_client() + parent = ( + f"projects/{self.gcp_project_id}/locations/{self.config.region}" ) - - parent = f"projects/{project_id}/locations/{self.config.region}" logger.info( "Submitting custom job='%s', path='%s' to Vertex AI Training.", job_request["display_name"], @@ -215,6 +231,5 @@ def launch( monitor_job( job_id=response.name, - credentials_source=self, - client_options=client_options, + get_client=self.get_job_service_client, ) diff --git a/src/zenml/integrations/gcp/utils.py b/src/zenml/integrations/gcp/utils.py index 5ab0042c87b..d5464530120 100644 --- a/src/zenml/integrations/gcp/utils.py +++ b/src/zenml/integrations/gcp/utils.py @@ -14,7 +14,7 @@ """Vertex utilities.""" import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from google.api_core.exceptions import ServerError from google.cloud import aiplatform @@ -25,9 +25,6 @@ VERTEX_JOB_STATES_COMPLETED, VERTEX_JOB_STATES_FAILED, ) -from zenml.integrations.gcp.google_credentials_mixin import ( - GoogleCredentialsMixin, -) from zenml.integrations.gcp.vertex_custom_job_parameters import ( VertexCustomJobParameters, ) @@ -55,59 +52,33 @@ def validate_accelerator_type(accelerator_type: Optional[str] = None) -> None: ) -def get_job_service_client( - credentials_source: GoogleCredentialsMixin, - client_options: Optional[Dict[str, Any]] = None, -) -> aiplatform.gapic.JobServiceClient: - """Gets a job service client. - - Args: - credentials_source: The component that provides the credentials to - access the job. - client_options: The client options to use for the job service client. - - Returns: - A job service client. - """ - credentials, _ = credentials_source._get_authentication() - return aiplatform.gapic.JobServiceClient( - credentials=credentials, client_options=client_options - ) - - def monitor_job( job_id: str, - credentials_source: GoogleCredentialsMixin, - client_options: Optional[Dict[str, Any]] = None, + get_client: Callable[[], aiplatform.gapic.JobServiceClient], ) -> None: """Monitors a job until it is completed. Args: job_id: The ID of the job to monitor. - credentials_source: The component that provides the credentials to - access the job. - client_options: The client options to use for the job service client. + get_client: A function that returns an authenticated job service client. Raises: RuntimeError: If the job fails. 
""" retry_count = 0 - client = get_job_service_client( - credentials_source=credentials_source, client_options=client_options - ) + client = get_client() while True: time.sleep(POLLING_INTERVAL_IN_SECONDS) - if credentials_source.connector_has_expired(): - client = get_job_service_client( - credentials_source=credentials_source, - client_options=client_options, - ) + # Fetch a fresh client in case the credentials have expired + client = get_client() try: response = client.get_custom_job(name=job_id) retry_count = 0 except (ConnectionError, ServerError) as err: + # Retry on connection errors, see also + # https://github.com/googleapis/google-api-python-client/issues/218 if retry_count < CONNECTION_ERROR_RETRY_LIMIT: retry_count += 1 logger.warning(