1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- """Placeholder docstring """
13+ """Utility methods used by framework classes """
1414from __future__ import absolute_import
1515
16- from collections import namedtuple
17-
1816import os
1917import re
2018import shutil
2119import tempfile
20+ from collections import namedtuple
2221
2322import sagemaker .utils
24- from sagemaker .utils import get_ecr_image_uri_prefix , ECR_URI_PATTERN
2523from sagemaker import s3
24+ from sagemaker .utils import get_ecr_image_uri_prefix , ECR_URI_PATTERN
2625
2726_TAR_SOURCE_FILENAME = "source.tar.gz"
2827
6968 "tensorflow-serving-eia" : "tensorflow-inference-eia" ,
7069 "mxnet" : "mxnet-training" ,
7170 "mxnet-serving" : "mxnet-inference" ,
71+ "mxnet-serving-eia" : "mxnet-inference-eia" ,
7272 "pytorch" : "pytorch-training" ,
7373 "pytorch-serving" : "pytorch-inference" ,
74- "mxnet-serving-eia" : "mxnet-inference-eia" ,
7574}
7675
7776MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
78- "tensorflow-scriptmode" : [1 , 13 , 1 ],
77+ "tensorflow-scriptmode" : { "py3" : [1 , 13 , 1 ], "py2" : [ 1 , 14 , 0 ]} ,
7978 "tensorflow-serving" : [1 , 13 , 0 ],
8079 "tensorflow-serving-eia" : [1 , 14 , 0 ],
81- "mxnet" : [1 , 4 , 1 ],
82- "mxnet-serving" : [1 , 4 , 1 ],
80+ "mxnet" : {"py3" : [1 , 4 , 1 ], "py2" : [1 , 6 , 0 ]},
81+ "mxnet-serving" : {"py3" : [1 , 4 , 1 ], "py2" : [1 , 6 , 0 ]},
82+ "mxnet-serving-eia" : [1 , 4 , 1 ],
8383 "pytorch" : [1 , 2 , 0 ],
8484 "pytorch-serving" : [1 , 2 , 0 ],
85- "mxnet-serving-eia" : [1 , 4 , 1 ],
8685}
8786
8887
8988def is_version_equal_or_higher (lowest_version , framework_version ):
9089 """Determine whether the ``framework_version`` is equal to or higher than
9190 ``lowest_version``
91+
9292 Args:
9393 lowest_version (List[int]): lowest version represented in an integer
9494 list
9595 framework_version (str): framework version string
96+
9697 Returns:
97- bool: Whether or not framework_version is equal to or higher than
98- lowest_version
98+ bool: Whether or not `` framework_version`` is equal to or higher than
99+ `` lowest_version``
99100 """
100101 version_list = [int (s ) for s in framework_version .split ("." )]
101102 return version_list >= lowest_version [0 : len (version_list )]
102103
103104
104- def _is_merged_versions (framework , framework_version ):
105- """
105+ def _is_dlc_version (framework , framework_version , py_version ):
106+ """Return if the framework's version uses the corresponding DLC image.
107+
106108 Args:
107- framework:
108- framework_version:
109+ framework (str): The framework name, e.g. "tensorflow-scriptmode"
110+ framework_version (str): The framework version
111+ py_version (str): The Python version, e.g. "py3"
112+
113+ Returns:
114+ bool: Whether or not the framework's version uses the DLC image.
109115 """
110116 lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS .get (framework )
117+ if isinstance (lowest_version_list , dict ):
118+ lowest_version_list = lowest_version_list [py_version ]
119+
111120 if lowest_version_list :
112121 return is_version_equal_or_higher (lowest_version_list , framework_version )
113122 return False
114123
115124
116- def _using_merged_images (region , framework , py_version , framework_version ):
117- """
118- Args:
119- region:
120- framework:
121- py_version:
122- accelerator_type:
123- framework_version:
124- """
125- is_gov_region = region in VALID_ACCOUNTS_BY_REGION
126- not_py2 = py_version == "py3" or py_version is None
127- is_merged_versions = _is_merged_versions (framework , framework_version )
128-
129- return (
130- ((not is_gov_region ) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION )
131- and is_merged_versions
132- # TODO: should be not mxnet-1.14.1-py2 instead?
133- and (
134- not_py2
135- or _is_tf_14_or_later (framework , framework_version )
136- or _is_pt_12_or_later (framework , framework_version )
137- or _is_mxnet_16_or_later (framework , framework_version )
138- )
139- )
140-
125+ def _use_dlc_image (region , framework , py_version , framework_version ):
126+ """Return if the DLC image should be used for the given framework,
127+ framework version, Python version, and region.
141128
142- def _is_tf_14_or_later (framework , framework_version ):
143- """
144129 Args:
145- framework:
146- framework_version:
147- """
148- # Asimov team now owns Tensorflow 1.14.0 py2 and py3
149- asimov_lowest_tf_py2 = [1 , 14 , 0 ]
150- version = [int (s ) for s in framework_version .split ("." )]
151- return (
152- framework == "tensorflow-scriptmode" and version >= asimov_lowest_tf_py2 [0 : len (version )]
153- )
154-
130+ region (str): The AWS region.
131+ framework (str): The framework name, e.g. "tensorflow-scriptmode".
132+ py_version (str): The Python version, e.g. "py3".
133+ framework_version (str): The framework version.
155134
156- def _is_pt_12_or_later (framework , framework_version ):
157- """
158- Args:
159- framework: Name of the frameowork
160- framework_version: framework version
135+ Returns:
136+ bool: Whether or not to use the corresponding DLC image.
161137 """
162- # Asimov team now owns PyTorch 1.2.0 py2 and py3
163- asimov_lowest_pt = [1 , 2 , 0 ]
164- version = [int (s ) for s in framework_version .split ("." )]
165- is_pytorch = framework in ("pytorch" , "pytorch-serving" )
166- return is_pytorch and version >= asimov_lowest_pt [0 : len (version )]
138+ is_gov_region = region in VALID_ACCOUNTS_BY_REGION
139+ is_dlc_version = _is_dlc_version (framework , framework_version , py_version )
167140
168-
169- def _is_mxnet_16_or_later (framework , framework_version ):
170- """
171- Args:
172- framework: Name of the frameowork
173- framework_version: framework version
174- """
175- # Asimov team now owns MXNet 1.6.0 py2 and py3
176- asimov_lowest_pt = [1 , 6 , 0 ]
177- version = [int (s ) for s in framework_version .split ("." )]
178- is_mxnet = framework in ("mxnet" , "mxnet-serving" )
179- return is_mxnet and version >= asimov_lowest_pt [0 : len (version )]
141+ return ((not is_gov_region ) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION ) and is_dlc_version
180142
181143
182144def _registry_id (region , framework , py_version , account , framework_version ):
183- """
145+ """Return the Amazon ECR registry number (or AWS account ID) for
146+ the given framework, framework version, Python version, and region.
147+
184148 Args:
185- region:
186- framework:
187- py_version:
188- account:
189- accelerator_type:
190- framework_version:
149+ region (str): The AWS region.
150+ framework (str): The framework name, e.g. "tensorflow-scriptmode".
151+ py_version (str): The Python version, e.g. "py3".
152+ account (str): The AWS account ID to use as a default.
153+ framework_version (str): The framework version.
154+
155+ Returns:
156+ str: The appropriate Amazon ECR registry number. If there is no
157+ specific one for the framework, framework version, Python version,
158+ and region, then ``account`` is returned.
191159 """
192- if _using_merged_images (region , framework , py_version , framework_version ):
160+ if _use_dlc_image (region , framework , py_version , framework_version ):
193161 if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION :
194162 return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION .get (region )
195163 if region in ASIMOV_VALID_ACCOUNTS_BY_REGION :
@@ -211,6 +179,7 @@ def create_image_uri(
211179 optimized_families = None ,
212180):
213181 """Return the ECR URI of an image.
182+
214183 Args:
215184 region (str): AWS region where the image is uploaded.
216185 framework (str): framework used by the image.
@@ -225,6 +194,7 @@ def create_image_uri(
225194 accelerator_type (str): SageMaker Elastic Inference accelerator type.
226195 optimized_families (str): Instance families for which there exist
227196 specific optimized images.
197+
228198 Returns:
229199 str: The appropriate image URI based on the given parameters.
230200 """
@@ -240,7 +210,7 @@ def create_image_uri(
240210 ):
241211 framework += "-eia"
242212
243- # Handle Account Number for Gov Cloud and frameworks with DLC merged images
213+ # Handle account number for specific cases (e.g. GovCloud, opt-in regions, DLC images etc.)
244214 if account is None :
245215 account = _registry_id (
246216 region = region ,
@@ -271,18 +241,19 @@ def create_image_uri(
271241 else :
272242 device_type = "cpu"
273243
274- using_merged_images = _using_merged_images (region , framework , py_version , framework_version )
244+ use_dlc_image = _use_dlc_image (region , framework , py_version , framework_version )
275245
276- if not py_version or (using_merged_images and framework == "tensorflow-serving-eia" ):
246+ if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia" ):
277247 tag = "{}-{}" .format (framework_version , device_type )
278248 else :
279249 tag = "{}-{}-{}" .format (framework_version , device_type , py_version )
280250
281- if using_merged_images :
282- return "{}/{}:{}" .format (
283- get_ecr_image_uri_prefix (account , region ), MERGED_FRAMEWORKS_REPO_MAP [framework ], tag
284- )
285- return "{}/sagemaker-{}:{}" .format (get_ecr_image_uri_prefix (account , region ), framework , tag )
251+ if use_dlc_image :
252+ ecr_repo = MERGED_FRAMEWORKS_REPO_MAP [framework ]
253+ else :
254+ ecr_repo = "sagemaker-{}" .format (framework )
255+
256+ return "{}/{}:{}" .format (get_ecr_image_uri_prefix (account , region ), ecr_repo , tag )
286257
287258
288259def _accelerator_type_valid_for_framework (
0 commit comments