@@ -63,10 +63,12 @@ class Step(Entity):
6363 Attributes:
6464 name (str): The name of the step.
6565 step_type (StepTypeEnum): The type of the step.
66+ depends_on (List[str]): The list of step names the current step depends on
6667 """
6768
6869 name : str = attr .ib (factory = str )
6970 step_type : StepTypeEnum = attr .ib (factory = StepTypeEnum .factory )
71+ depends_on : List [str ] = attr .ib (default = None )
7072
7173 @property
7274 @abc .abstractmethod
@@ -80,11 +82,22 @@ def properties(self):
8082
8183 def to_request (self ) -> RequestType :
8284 """Gets the request structure for workflow service calls."""
83- return {
85+ request_dict = {
8486 "Name" : self .name ,
8587 "Type" : self .step_type .value ,
8688 "Arguments" : self .arguments ,
8789 }
90+ if self .depends_on :
91+ request_dict ["DependsOn" ] = self .depends_on
92+ return request_dict
93+
94+ def add_depends_on (self , step_names : List [str ]):
95+ """Add step names to the current step depends on list"""
96+ if not step_names :
97+ return
98+ if not self .depends_on :
99+ self .depends_on = []
100+ self .depends_on .extend (step_names )
88101
89102 @property
90103 def ref (self ) -> Dict [str , str ]:
@@ -133,6 +146,7 @@ def __init__(
133146 estimator : EstimatorBase ,
134147 inputs : TrainingInput = None ,
135148 cache_config : CacheConfig = None ,
149+ depends_on : List [str ] = None ,
136150 ):
137151 """Construct a TrainingStep, given an `EstimatorBase` instance.
138152
@@ -144,8 +158,10 @@ def __init__(
144158 estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
145159 inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
146160 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
161+ depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
162+ depends on
147163 """
148- super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING )
164+ super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING , depends_on )
149165 self .estimator = estimator
150166 self .inputs = inputs
151167 self ._properties = Properties (
@@ -188,10 +204,7 @@ class CreateModelStep(Step):
188204 """CreateModel step for workflow."""
189205
190206 def __init__ (
191- self ,
192- name : str ,
193- model : Model ,
194- inputs : CreateModelInput ,
207+ self , name : str , model : Model , inputs : CreateModelInput , depends_on : List [str ] = None
195208 ):
196209 """Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
197210
@@ -203,8 +216,10 @@ def __init__(
203216 model (Model): A `sagemaker.model.Model` instance.
204217 inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
205218 Defaults to `None`.
219+ depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
220+ depends on
206221 """
207- super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL )
222+ super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL , depends_on )
208223 self .model = model
209224 self .inputs = inputs or CreateModelInput ()
210225
@@ -247,6 +262,7 @@ def __init__(
247262 transformer : Transformer ,
248263 inputs : TransformInput ,
249264 cache_config : CacheConfig = None ,
265+ depends_on : List [str ] = None ,
250266 ):
251267 """Constructs a TransformStep, given an `Transformer` instance.
252268
@@ -258,8 +274,10 @@ def __init__(
258274 transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
259275 inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
260276 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
277+ depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
278+ depends on
261279 """
262- super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM )
280+ super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM , depends_on )
263281 self .transformer = transformer
264282 self .inputs = inputs
265283 self .cache_config = cache_config
@@ -320,6 +338,7 @@ def __init__(
320338 code : str = None ,
321339 property_files : List [PropertyFile ] = None ,
322340 cache_config : CacheConfig = None ,
341+ depends_on : List [str ] = None ,
323342 ):
324343 """Construct a ProcessingStep, given a `Processor` instance.
325344
@@ -340,8 +359,10 @@ def __init__(
340359 property_files (List[PropertyFile]): A list of property files that workflow looks
341360 for and resolves from the configured processing output list.
342361 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
362+ depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
363+ depends on
343364 """
344- super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING )
365+ super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING , depends_on )
345366 self .processor = processor
346367 self .inputs = inputs
347368 self .outputs = outputs
0 commit comments