4040from sagemaker .workflow .pipeline_experiment_config import PipelineExperimentConfig
4141from sagemaker .workflow .parallelism_config import ParallelismConfiguration
4242from sagemaker .workflow .properties import Properties
43+ from sagemaker .workflow .selective_execution_config import SelectiveExecutionConfig
4344from sagemaker .workflow .steps import Step , StepTypeEnum
4445from sagemaker .workflow .step_collections import StepCollection
4546from sagemaker .workflow .condition_step import ConditionStep
@@ -312,6 +313,7 @@ def start(
312313 execution_display_name : str = None ,
313314 execution_description : str = None ,
314315 parallelism_config : ParallelismConfiguration = None ,
316+ selective_execution_config : SelectiveExecutionConfig = None ,
315317 ):
316318 """Starts a Pipeline execution in the Workflow service.
317319
@@ -323,16 +325,26 @@ def start(
323325 parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
324326 that is applied to each of the executions of the pipeline. It takes precedence
325327 over the parallelism configuration of the parent pipeline.
328+ selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
329+ selective step execution.
326330
327331 Returns:
328332 A `_PipelineExecution` instance, if successful.
329333 """
334+ if selective_execution_config is not None :
335+ if selective_execution_config .source_pipeline_execution_arn is None :
336+ selective_execution_config .source_pipeline_execution_arn = (
337+ self ._get_latest_execution_arn ()
338+ )
339+ selective_execution_config = selective_execution_config .to_request ()
340+
330341 kwargs = dict (PipelineName = self .name )
331342 update_args (
332343 kwargs ,
333344 PipelineExecutionDescription = execution_description ,
334345 PipelineExecutionDisplayName = execution_display_name ,
335346 ParallelismConfiguration = parallelism_config ,
347+ SelectiveExecutionConfig = selective_execution_config ,
336348 )
337349 if self .sagemaker_session .local_mode :
338350 update_args (kwargs , PipelineParameters = parameters )
@@ -388,6 +400,57 @@ def _interpolate_step_collection_name_in_depends_on(self, step_requests: list):
388400 )
389401 self ._interpolate_step_collection_name_in_depends_on (sub_step_requests )
390402
403+ def list_executions (
404+ self ,
405+ sort_by : str = None ,
406+ sort_order : str = None ,
407+ max_results : int = None ,
408+ next_token : str = None ,
409+ ) -> Dict [str , Any ]:
410+ """Lists a pipeline's executions.
411+
412+ Args:
413+ sort_by (str): The field by which to sort results(CreationTime/PipelineExecutionArn).
414+ sort_order (str): The sort order for results (Ascending/Descending).
415+ max_results (int): The maximum number of pipeline executions to return in the response.
416+ next_token (str): If the result of the previous ListPipelineExecutions request was
417+ truncated, the response includes a NextToken. To retrieve the next set of pipeline
418+ executions, use the token in the next request.
419+
420+ Returns:
421+ List of Pipeline Execution Summaries. See
422+ boto3 client list_pipeline_executions
423+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_pipeline_executions
424+ """
425+ kwargs = dict (PipelineName = self .name )
426+ update_args (
427+ kwargs ,
428+ SortBy = sort_by ,
429+ SortOrder = sort_order ,
430+ NextToken = next_token ,
431+ MaxResults = max_results ,
432+ )
433+ response = self .sagemaker_session .sagemaker_client .list_pipeline_executions (** kwargs )
434+
435+ # Return only PipelineExecutionSummaries and NextToken from the list_pipeline_executions
436+ # response
437+ return {
438+ key : response [key ]
439+ for key in ["PipelineExecutionSummaries" , "NextToken" ]
440+ if key in response
441+ }
442+
443+ def _get_latest_execution_arn (self ):
444+ """Retrieves the latest execution of this pipeline"""
445+ response = self .list_executions (
446+ sort_by = "CreationTime" ,
447+ sort_order = "Descending" ,
448+ max_results = 1 ,
449+ )
450+ if response ["PipelineExecutionSummaries" ]:
451+ return response ["PipelineExecutionSummaries" ][0 ]["PipelineExecutionArn" ]
452+ return None
453+
391454
392455def format_start_parameters (parameters : Dict [str , Any ]) -> List [Dict [str , Any ]]:
393456 """Formats start parameter overrides as a list of dicts.
0 commit comments