3434XGBOOST_FRAMEWORK = "xgboost"
3535SKLEARN_FRAMEWORK = "sklearn"
3636TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
37+ INFERENCE_GRAVITON = "inference_graviton"
3738
3839
3940@override_pipeline_parameter_var
@@ -75,8 +76,8 @@ def retrieve(
7576 accelerator_type (str): Elastic Inference accelerator type. For more, see
7677 https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
7778 image_scope (str): The image type, i.e. what it is used for.
78- Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
79- ``image_scope`` is ignored.
79+ Valid values: "training", "inference", "inference_graviton", " eia".
80+ If ``accelerator_type`` is set, ``image_scope`` is ignored.
8081 container_version (str): the version of docker image.
8182 Ideally the value of parameter should be created inside the framework.
8283 For custom use, see the list of supported container versions:
@@ -146,8 +147,9 @@ def retrieve(
146147 )
147148
148149 if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK ):
150+ final_image_scope = image_scope
149151 config = _config_for_framework_and_scope (
150- framework + "-training-compiler" , image_scope , accelerator_type
152+ framework + "-training-compiler" , final_image_scope , accelerator_type
151153 )
152154 else :
153155 _framework = framework
@@ -234,6 +236,7 @@ def retrieve(
234236 tag = _get_image_tag (
235237 container_version ,
236238 distribution ,
239+ final_image_scope ,
237240 framework ,
238241 inference_tool ,
239242 instance_type ,
@@ -266,6 +269,7 @@ def _get_instance_type_family(instance_type):
266269def _get_image_tag (
267270 container_version ,
268271 distribution ,
272+ final_image_scope ,
269273 framework ,
270274 inference_tool ,
271275 instance_type ,
@@ -276,20 +280,29 @@ def _get_image_tag(
276280):
277281 """Return image tag based on framework, container, and compute configuration(s)."""
278282 instance_type_family = _get_instance_type_family (instance_type )
279- if (
280- framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK )
281- and instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
282- ):
283- version_to_arm64_tag_mapping = {
284- "xgboost" : {
285- "1.5-1" : "1.5-1-arm64" ,
286- "1.3-1" : "1.3-1-arm64" ,
287- },
288- "sklearn" : {
289- "1.0-1" : "1.0-1-arm64-cpu-py3" ,
290- },
291- }
292- tag = version_to_arm64_tag_mapping [framework ][version ]
283+ if framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
284+ if instance_type_family and final_image_scope == INFERENCE_GRAVITON :
285+ _validate_arg (
286+ instance_type_family ,
287+ GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY ,
288+ "instance type" ,
289+ )
290+ if (
291+ instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
292+ or final_image_scope == INFERENCE_GRAVITON
293+ ):
294+ version_to_arm64_tag_mapping = {
295+ "xgboost" : {
296+ "1.5-1" : "1.5-1-arm64" ,
297+ "1.3-1" : "1.3-1-arm64" ,
298+ },
299+ "sklearn" : {
300+ "1.0-1" : "1.0-1-arm64-cpu-py3" ,
301+ },
302+ }
303+ tag = version_to_arm64_tag_mapping [framework ][version ]
304+ else :
305+ tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
293306 else :
294307 tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
295308
@@ -375,7 +388,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
375388 framework in GRAVITON_ALLOWED_FRAMEWORKS
376389 and _get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
377390 ):
378- return "inference_graviton"
391+ return INFERENCE_GRAVITON
379392 if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
380393 # Preserves backwards compatibility with XGB/SKLearn configs which no
381394 # longer define top-level "scope" keys after introducing support for
0 commit comments