@@ -77,6 +77,17 @@ class _ExecutionVariable:
7777 name : str
7878
7979
80+ @dataclass
81+ class _S3BaseUriIdentifier :
82+ """Placeholder that stands in for the s3 base uri of a function step.
83+
84+ The s3_base_uri = s3_root_uri + pipeline_name.
85+ The SDK resolves this identifier to the actual uri at function step runtime.
86+ """
87+
88+ NAME = "S3_BASE_URI"
89+
90+
8091@dataclass
8192class _DelayedReturn :
8293 """Delayed return from a function."""
@@ -155,6 +166,7 @@ def __init__(
155166 hmac_key : str ,
156167 parameter_resolver : _ParameterResolver ,
157168 execution_variable_resolver : _ExecutionVariableResolver ,
169+ s3_base_uri : str ,
158170 ** settings ,
159171 ):
160172 """Resolve delayed return.
@@ -164,8 +176,12 @@ def __init__(
164176 hmac_key: key used to encrypt serialized and deserialized function and arguments.
165177 parameter_resolver: resolver used to resolve pipeline parameters.
166178 execution_variable_resolver: resolver used to resolve execution variables.
179+ s3_base_uri (str): the s3 base uri of the function step that
180+ the serialized artifacts will be uploaded to.
181+ The s3_base_uri = s3_root_uri + pipeline_name.
167182 **settings: settings to pass to the deserialization function.
168183 """
184+ self ._s3_base_uri = s3_base_uri
169185 self ._parameter_resolver = parameter_resolver
170186 self ._execution_variable_resolver = execution_variable_resolver
171187 # different delayed returns can have the same uri, so we need to dedupe
@@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
205221 uri .append (self ._parameter_resolver .resolve (component ))
206222 elif isinstance (component , _ExecutionVariable ):
207223 uri .append (self ._execution_variable_resolver .resolve (component ))
224+ elif isinstance (component , _S3BaseUriIdentifier ):
225+ uri .append (self ._s3_base_uri )
208226 else :
209227 uri .append (component )
210228 return s3_path_join (* uri )
@@ -219,7 +237,12 @@ def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any):
219237
220238
221239def resolve_pipeline_variables (
222- context : Context , func_args : Tuple , func_kwargs : Dict , hmac_key : str , ** settings
240+ context : Context ,
241+ func_args : Tuple ,
242+ func_kwargs : Dict ,
243+ hmac_key : str ,
244+ s3_base_uri : str ,
245+ ** settings ,
223246):
224247 """Resolve pipeline variables.
225248
@@ -228,6 +251,8 @@ def resolve_pipeline_variables(
228251 func_args: function args.
229252 func_kwargs: function kwargs.
230253 hmac_key: key used to encrypt serialized and deserialized function and arguments.
254+ s3_base_uri: the s3 base uri of the function step that the serialized artifacts
255+ will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
231256 **settings: settings to pass to the deserialization function.
232257 """
233258
@@ -251,6 +276,7 @@ def resolve_pipeline_variables(
251276 hmac_key = hmac_key ,
252277 parameter_resolver = parameter_resolver ,
253278 execution_variable_resolver = execution_variable_resolver ,
279+ s3_base_uri = s3_base_uri ,
254280 ** settings ,
255281 )
256282
@@ -289,11 +315,10 @@ def resolve_pipeline_variables(
289315 return resolved_func_args , resolved_func_kwargs
290316
291317
292- def convert_pipeline_variables_to_pickleable (s3_base_uri : str , func_args : Tuple , func_kwargs : Dict ):
318+ def convert_pipeline_variables_to_pickleable (func_args : Tuple , func_kwargs : Dict ):
293319 """Convert pipeline variables to pickleable.
294320
295321 Args:
296- s3_base_uri: s3 base uri where artifacts are stored.
297322 func_args: function args.
298323 func_kwargs: function kwargs.
299324 """
@@ -304,11 +329,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple,
304329
305330 from sagemaker .workflow .function_step import DelayedReturn
306331
332+ # Notes:
333+ # 1. The s3_base_uri = s3_root_uri + pipeline_name, but both may be unknown
334+ # when function steps are defined. After step-level arg serialization,
335+ # it is hard to update the s3_base_uri at pipeline compile time.
336+ # So a placeholder, _S3BaseUriIdentifier, is set and the runtime job resolves it.
337+ # 2. The s3_root_uri is unknown because, when function steps are defined,
338+ # the pipeline's sagemaker_session is not passed in, yet the default s3_root_uri
339+ # should be retrieved from the pipeline's sagemaker_session.
307340 def convert (arg ):
308341 if isinstance (arg , DelayedReturn ):
309342 return _DelayedReturn (
310343 uri = [
311- s3_base_uri ,
344+ _S3BaseUriIdentifier () ,
312345 ExecutionVariables .PIPELINE_EXECUTION_ID ._pickleable ,
313346 arg ._step .name ,
314347 "results" ,
0 commit comments