@@ -155,6 +155,8 @@ def __init__(
155155 entry_point : Optional [Union [str , PipelineVariable ]] = None ,
156156 dependencies : Optional [List [Union [str ]]] = None ,
157157 instance_groups : Optional [List [InstanceGroup ]] = None ,
158+ training_repository_access_mode : Optional [Union [str , PipelineVariable ]] = None ,
159+ training_repository_credentials_provider_arn : Optional [Union [str , PipelineVariable ]] = None ,
158160 ** kwargs ,
159161 ):
160162 """Initialize an ``EstimatorBase`` instance.
@@ -489,6 +491,18 @@ def __init__(
489491 `Train Using a Heterogeneous Cluster
490492 <https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
491493 in the *Amazon SageMaker developer guide*.
494+ training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
495+ Docker image that contains the training algorithm (default: None).
496+ Set this to one of the following values:
497+ * 'Platform' - The training image is hosted in Amazon ECR.
498+ * 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
499+ When it's default to None, its behavior will be same as 'Platform' - image is hosted
500+ in ECR.
501+ training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
502+ (ARN) of an AWS Lambda function that provides credentials to authenticate to the
503+ private Docker registry where your training image is hosted (default: None).
504+ When it's set to None, SageMaker will not do authentication before pulling the image
505+ in the private Docker registry.
492506 """
493507 instance_count = renamed_kwargs (
494508 "train_instance_count" , "instance_count" , instance_count , kwargs
@@ -536,7 +550,9 @@ def __init__(
536550 self .dependencies = dependencies or []
537551 self .uploaded_code = None
538552 self .tags = add_jumpstart_tags (
539- tags = tags , training_model_uri = self .model_uri , training_script_uri = self .source_dir
553+ tags = tags ,
554+ training_model_uri = self .model_uri ,
555+ training_script_uri = self .source_dir ,
540556 )
541557 if self .instance_type in ("local" , "local_gpu" ):
542558 if self .instance_type == "local_gpu" and self .instance_count > 1 :
@@ -571,6 +587,12 @@ def __init__(
571587 self .subnets = subnets
572588 self .security_group_ids = security_group_ids
573589
590+ # training image configs
591+ self .training_repository_access_mode = training_repository_access_mode
592+ self .training_repository_credentials_provider_arn = (
593+ training_repository_credentials_provider_arn
594+ )
595+
574596 self .encrypt_inter_container_traffic = encrypt_inter_container_traffic
575597 self .use_spot_instances = use_spot_instances
576598 self .max_wait = max_wait
@@ -651,7 +673,8 @@ def _ensure_base_job_name(self):
651673 self .base_job_name
652674 or get_jumpstart_base_name_if_jumpstart_model (self .source_dir , self .model_uri )
653675 or base_name_from_image (
654- self .training_image_uri (), default_base_name = EstimatorBase .JOB_CLASS_NAME
676+ self .training_image_uri (),
677+ default_base_name = EstimatorBase .JOB_CLASS_NAME ,
655678 )
656679 )
657680
@@ -1405,7 +1428,10 @@ def deploy(
14051428 self ._ensure_base_job_name ()
14061429
14071430 jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model (
1408- kwargs .get ("source_dir" ), self .source_dir , kwargs .get ("model_data" ), self .model_uri
1431+ kwargs .get ("source_dir" ),
1432+ self .source_dir ,
1433+ kwargs .get ("model_data" ),
1434+ self .model_uri ,
14091435 )
14101436 default_name = (
14111437 name_from_base (jumpstart_base_name )
@@ -1638,6 +1664,15 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
16381664 init_params ["algorithm_arn" ] = job_details ["AlgorithmSpecification" ]["AlgorithmName" ]
16391665 elif "TrainingImage" in job_details ["AlgorithmSpecification" ]:
16401666 init_params ["image_uri" ] = job_details ["AlgorithmSpecification" ]["TrainingImage" ]
1667+ if "TrainingImageConfig" in job_details ["AlgorithmSpecification" ]:
1668+ init_params ["training_repository_access_mode" ] = job_details [
1669+ "AlgorithmSpecification"
1670+ ]["TrainingImageConfig" ].get ("TrainingRepositoryAccessMode" )
1671+ init_params ["training_repository_credentials_provider_arn" ] = (
1672+ job_details ["AlgorithmSpecification" ]["TrainingImageConfig" ]
1673+ .get ("TrainingRepositoryAuthConfig" , {})
1674+ .get ("TrainingRepositoryCredentialsProviderArn" )
1675+ )
16411676 else :
16421677 raise RuntimeError (
16431678 "Invalid AlgorithmSpecification. Either TrainingImage or "
@@ -2118,6 +2153,17 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
21182153 else :
21192154 train_args ["retry_strategy" ] = None
21202155
2156+ if estimator .training_repository_access_mode is not None :
2157+ training_image_config = {
2158+ "TrainingRepositoryAccessMode" : estimator .training_repository_access_mode
2159+ }
2160+ if estimator .training_repository_credentials_provider_arn is not None :
2161+ training_image_config ["TrainingRepositoryAuthConfig" ] = {}
2162+ training_image_config ["TrainingRepositoryAuthConfig" ][
2163+ "TrainingRepositoryCredentialsProviderArn"
2164+ ] = estimator .training_repository_credentials_provider_arn
2165+ train_args ["training_image_config" ] = training_image_config
2166+
21212167 # encrypt_inter_container_traffic may be a pipeline variable place holder object
21222168 # which is parsed in execution time
21232169 if estimator .encrypt_inter_container_traffic :
@@ -2182,7 +2228,11 @@ def _is_local_channel(cls, input_uri):
21822228
21832229 @classmethod
21842230 def update (
2185- cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
2231+ cls ,
2232+ estimator ,
2233+ profiler_rule_configs = None ,
2234+ profiler_config = None ,
2235+ resource_config = None ,
21862236 ):
21872237 """Update a running Amazon SageMaker training job.
21882238
@@ -2321,6 +2371,8 @@ def __init__(
23212371 entry_point : Optional [Union [str , PipelineVariable ]] = None ,
23222372 dependencies : Optional [List [str ]] = None ,
23232373 instance_groups : Optional [List [InstanceGroup ]] = None ,
2374+ training_repository_access_mode : Optional [Union [str , PipelineVariable ]] = None ,
2375+ training_repository_credentials_provider_arn : Optional [Union [str , PipelineVariable ]] = None ,
23242376 ** kwargs ,
23252377 ):
23262378 """Initialize an ``Estimator`` instance.
@@ -2654,6 +2706,18 @@ def __init__(
26542706 `Train Using a Heterogeneous Cluster
26552707 <https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
26562708 in the *Amazon SageMaker developer guide*.
2709+ training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
2710+ Docker image that contains the training algorithm (default: None).
2711+ Set this to one of the following values:
2712+ * 'Platform' - The training image is hosted in Amazon ECR.
2713+ * 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
2714+ When it's default to None, its behavior will be same as 'Platform' - image is hosted
2715+ in ECR.
2716+ training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
2717+ (ARN) of an AWS Lambda function that provides credentials to authenticate to the
2718+ private Docker registry where your training image is hosted (default: None).
2719+ When it's set to None, SageMaker will not do authentication before pulling the image
2720+ in the private Docker registry.
26572721 """
26582722 self .image_uri = image_uri
26592723 self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -2698,6 +2762,8 @@ def __init__(
26982762 dependencies = dependencies ,
26992763 hyperparameters = hyperparameters ,
27002764 instance_groups = instance_groups ,
2765+ training_repository_access_mode = training_repository_access_mode ,
2766+ training_repository_credentials_provider_arn = training_repository_credentials_provider_arn , # noqa: E501 # pylint: disable=line-too-long
27012767 ** kwargs ,
27022768 )
27032769
0 commit comments