@@ -145,6 +145,7 @@ def __init__(
145145 code_location : Optional [str ] = None ,
146146 entry_point : Optional [str ] = None ,
147147 dependencies : Optional [List [Union [str ]]] = None ,
148+ instance_groups = None ,
148149 ** kwargs ,
149150 ):
150151 """Initialize an ``EstimatorBase`` instance.
@@ -156,9 +157,10 @@ def __init__(
156157 artifacts. After the endpoint is created, the inference code
157158 might use the IAM role, if it needs to access an AWS resource.
158159 instance_count (int): Number of Amazon EC2 instances to use
159- for training.
160+ for training. Required if instance_groups is not set.
160161 instance_type (str): Type of EC2 instance to use for training,
161- for example, 'ml.c4.xlarge'.
162+ for example, 'ml.c4.xlarge'. Required if instance_groups is
163+ not set.
162164 volume_size (int): Size in GB of the EBS volume to use for
163165 storing input data during training (default: 30). Must be large
164166 enough to store training data if File Mode is used (which is the
@@ -424,7 +426,10 @@ def __init__(
424426 >>> |------ virtual-env
425427
426428 This is not supported with "local code" in Local Mode.
427-
429+ instance_groups (list[InstanceGroup]): Optional. List of InstanceGroup
430+ for specifying different instance groups for heterogeneous cluster.
431+ For example: [sagemaker.InstanceGroup('worker','ml.p3dn.24xlarge',64),
432+ sagemaker.InstanceGroup('server','ml.c5n.18xlarge',64)]
428433 """
429434 instance_count = renamed_kwargs (
430435 "train_instance_count" , "instance_count" , instance_count , kwargs
@@ -442,12 +447,10 @@ def __init__(
442447 "train_volume_kms_key" , "volume_kms_key" , volume_kms_key , kwargs
443448 )
444449
445- if instance_count is None or instance_type is None :
446- raise ValueError ("Both instance_count and instance_type are required." )
447-
448450 self .role = role
449451 self .instance_count = instance_count
450452 self .instance_type = instance_type
453+ self .instance_groups = instance_groups
451454 self .volume_size = volume_size
452455 self .volume_kms_key = volume_kms_key
453456 self .max_run = max_run
@@ -2103,6 +2106,7 @@ def __init__(
21032106 code_location : Optional [str ] = None ,
21042107 entry_point : Optional [str ] = None ,
21052108 dependencies : Optional [List [str ]] = None ,
2109+ instance_groups = None ,
21062110 ** kwargs ,
21072111 ):
21082112 """Initialize an ``Estimator`` instance.
@@ -2115,9 +2119,10 @@ def __init__(
21152119 artifacts. After the endpoint is created, the inference code
21162120 might use the IAM role, if it needs to access an AWS resource.
21172121 instance_count (int): Number of Amazon EC2 instances to use
2118- for training.
2122+ for training. Required if instance_groups is not set.
21192123 instance_type (str): Type of EC2 instance to use for training,
2120- for example, 'ml.c4.xlarge'.
2124+ for example, 'ml.c4.xlarge'. Required if instance_groups is
2125+ not set.
21212126 volume_size (int): Size in GB of the EBS volume to use for
21222127 storing input data during training (default: 30). Must be large
21232128 enough to store training data if File Mode is used (which is the
@@ -2379,13 +2384,18 @@ def __init__(
23792384 >>> |------ virtual-env
23802385
23812386 This is not supported with "local code" in Local Mode.
2387+ instance_groups (list[InstanceGroup]): Optional. List of InstanceGroup
2388+ for specifying different instance groups for heterogeneous cluster.
2389+ For example: [sagemaker.InstanceGroup('worker','ml.p3dn.24xlarge',64),
2390+ sagemaker.InstanceGroup('server','ml.c5n.18xlarge',64)]
23822391 """
23832392 self .image_uri = image_uri
23842393 self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
23852394 super (Estimator , self ).__init__ (
23862395 role ,
23872396 instance_count ,
23882397 instance_type ,
2398+ instance_groups ,
23892399 volume_size ,
23902400 volume_kms_key ,
23912401 max_run ,
0 commit comments