|
13 | 13 | """Utilities to support workflow.""" |
14 | 14 | from __future__ import absolute_import |
15 | 15 |
|
| 16 | +import inspect |
| 17 | +import logging |
| 18 | +from functools import wraps |
16 | 19 | from pathlib import Path |
17 | | -from typing import List, Sequence, Union, Set |
| 20 | +from typing import List, Sequence, Union, Set, TYPE_CHECKING |
18 | 21 | import hashlib |
19 | 22 | from urllib.parse import unquote, urlparse |
20 | 23 | from _hashlib import HASH as Hash |
21 | 24 |
|
| 25 | +from sagemaker.workflow.parameters import Parameter |
22 | 26 | from sagemaker.workflow.pipeline_context import _StepArguments |
23 | | -from sagemaker.workflow.step_collections import StepCollection |
24 | 27 | from sagemaker.workflow.entities import ( |
25 | 28 | Entity, |
26 | 29 | RequestType, |
27 | 30 | ) |
28 | 31 |
|
| 32 | +if TYPE_CHECKING: |
| 33 | + from sagemaker.workflow.step_collections import StepCollection |
| 34 | + |
29 | 35 | BUF_SIZE = 65536 # 64KiB |
30 | 36 |
|
31 | 37 |
|
32 | | -def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[RequestType]: |
| 38 | +def list_to_request(entities: Sequence[Union[Entity, "StepCollection"]]) -> List[RequestType]: |
33 | 39 | """Get the request structure for list of entities. |
34 | 40 |
|
35 | 41 | Args: |
36 | 42 | entities (Sequence[Entity]): A list of entities. |
37 | 43 | Returns: |
38 | 44 | list: A request structure for a workflow service call. |
39 | 45 | """ |
| 46 | + from sagemaker.workflow.step_collections import StepCollection |
| 47 | + |
40 | 48 | request_dicts = [] |
41 | 49 | for entity in entities: |
42 | 50 | if isinstance(entity, Entity): |
@@ -151,3 +159,41 @@ def validate_step_args_input( |
151 | 159 | raise TypeError(error_message) |
152 | 160 | if step_args.caller_name not in expected_caller: |
153 | 161 | raise ValueError(error_message) |
| 162 | + |
| 163 | + |
| 164 | +def override_pipeline_parameter_var(func): |
| 165 | + """A decorator to override pipeline Parameters passed into a function |
| 166 | +
|
| 167 | + This is a temporary decorator to override pipeline Parameter objects with their default value |
| 168 | + and display warning information to instruct users to update their code. |
| 169 | +
|
| 170 | + This decorator can help to give a grace period for users to update their code when |
| 171 | + we make changes to explicitly prevent passing any pipeline variables to a function. |
| 172 | +
|
| 173 | + We should remove this decorator after the grace period. |
| 174 | + """ |
| 175 | + warning_msg_template = ( |
| 176 | + "%s should not be a pipeline variable (%s). " |
| 177 | + "The default_value of this Parameter object will be used to override it. " |
| 178 | + "Please remove this pipeline variable and use python primitives instead." |
| 179 | + ) |
| 180 | + |
| 181 | + @wraps(func) |
| 182 | + def wrapper(*args, **kwargs): |
| 183 | + params = inspect.signature(func).parameters |
| 184 | + args = list(args) |
| 185 | + for i, (arg_name, _) in enumerate(params.items()): |
| 186 | + if i >= len(args): |
| 187 | + break |
| 188 | + if isinstance(args[i], Parameter): |
| 189 | + logging.warning(warning_msg_template, arg_name, type(args[i])) |
| 190 | + args[i] = args[i].default_value |
| 191 | + args = tuple(args) |
| 192 | + |
| 193 | + for arg_name, value in kwargs.items(): |
| 194 | + if isinstance(value, Parameter): |
| 195 | + logging.warning(warning_msg_template, arg_name, type(value)) |
| 196 | + kwargs[arg_name] = value.default_value |
| 197 | + return func(*args, **kwargs) |
| 198 | + |
| 199 | + return wrapper |
0 commit comments