@@ -35,6 +35,7 @@ def retrieve(
3535 accelerator_type = None ,
3636 image_scope = None ,
3737 container_version = None ,
38+ distribution = None ,
3839):
3940 """Retrieves the ECR URI for the Docker image matching the given arguments.
4041
@@ -54,6 +55,8 @@ def retrieve(
5455 Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
5556 ``image_scope`` is ignored.
5657 container_version (str): the version of docker image
58+ distribution (dict): A dictionary with information on how to run distributed training
59+ (default: None).
5760
5861 Returns:
5962 str: the ECR URI for the corresponding SageMaker Docker image.
@@ -77,10 +80,25 @@ def retrieve(
7780 processor = _processor (
7881 instance_type , config .get ("processors" ) or version_config .get ("processors" )
7982 )
83+
8084 tag = _format_tag (
81- version_config .get ("tag_prefix" , version ), processor , py_version , container_version
85+ version_config .get ("tag_prefix" , version ),
86+ processor ,
87+ py_version ,
88+ container_version ,
8289 )
8390
91+ if _should_auto_select_container_version (instance_type , distribution ):
92+ container_versions = {
93+ "tensorflow-2.3-gpu-py37" : "cu110-ubuntu18.04-v3" ,
94+ "tensorflow-1.15-gpu-py37" : "cu110-ubuntu18.04-v8" ,
95+ "mxnet-1.8-gpu-py37" : "cu110-ubuntu16.04-v1" ,
96+ "pytorch-1.6-gpu-py36" : "cu110-ubuntu18.04-v3" ,
97+ }
98+ key = "-" .join ([framework , tag ])
99+ if key in container_versions :
100+ tag = "-" .join ([tag , container_versions [key ]])
101+
84102 if tag :
85103 repo += ":{}" .format (tag )
86104
@@ -217,6 +235,23 @@ def _processor(instance_type, available_processors):
217235 return processor
218236
219237
238+ def _should_auto_select_container_version (instance_type , distribution ):
239+ """Returns a boolean that indicates whether to use an auto-selected container version."""
240+ p4d = False
241+ if instance_type :
242+ # looks for either "ml.<family>.<size>" or "ml_<family>"
243+ match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
244+ if match :
245+ family = match [1 ]
246+ p4d = family == "p4d"
247+
248+ smdistributed = False
249+ if distribution :
250+ smdistributed = "smdistributed" in distribution
251+
252+ return p4d or smdistributed
253+
254+
220255def _validate_py_version_and_set_if_needed (py_version , version_config , framework ):
221256 """Checks if the Python version is one of the supported versions."""
222257 if "repository" in version_config :
0 commit comments