@@ -49,6 +49,7 @@ def __init__(
4949 channel_type = None ,
5050 content_type = None ,
5151 s3_data_type = None ,
52+ sample_weight_attribute_name = None ,
5253 ):
5354 """Convert an S3 Uri or a list of S3 Uri to an AutoMLInput object.
5455
@@ -67,13 +68,16 @@ def __init__(
6768 The content type of the data from the input source.
6869 s3_data_type (str, PipelineVariable): The data type for S3 data source.
6970 Valid values: ManifestFile or S3Prefix.
71+ sample_weight_attribute_name (str, PipelineVariable):
72+ The name of the dataset column representing sample weights.
7073 """
7174 self .inputs = inputs
7275 self .target_attribute_name = target_attribute_name
7376 self .compression = compression
7477 self .channel_type = channel_type
7578 self .content_type = content_type
7679 self .s3_data_type = s3_data_type
80+ self .sample_weight_attribute_name = sample_weight_attribute_name
7781
7882 def to_request_dict (self ):
7983 """Generates a request dictionary using the parameters provided to the class."""
@@ -96,6 +100,8 @@ def to_request_dict(self):
96100 input_entry ["ContentType" ] = self .content_type
97101 if self .s3_data_type is not None :
98102 input_entry ["DataSource" ]["S3DataSource" ]["S3DataType" ] = self .s3_data_type
103+ if self .sample_weight_attribute_name is not None :
104+ input_entry ["SampleWeightAttributeName" ] = self .sample_weight_attribute_name
99105 auto_ml_input .append (input_entry )
100106 return auto_ml_input
101107
@@ -129,6 +135,7 @@ def __init__(
129135 mode : Optional [str ] = None ,
130136 auto_generate_endpoint_name : Optional [bool ] = None ,
131137 endpoint_name : Optional [str ] = None ,
138+ sample_weight_attribute_name : str = None ,
132139 ):
133140 """Initialize an AutoML object.
134141
@@ -179,6 +186,8 @@ def __init__(
179186 model deployment if the endpoint name is not generated automatically.
180187 Specify the endpoint_name if and only if
181188 auto_generate_endpoint_name is set to False
189+ sample_weight_attribute_name (str): The name of the dataset column representing
190+ sample weights.
182191
183192 Returns:
184193 AutoML object.
@@ -234,6 +243,7 @@ def __init__(
234243 )
235244
236245 self ._check_problem_type_and_job_objective (self .problem_type , self .job_objective )
246+ self .sample_weight_attribute_name = sample_weight_attribute_name
237247
238248 @runnable_by_pipeline
239249 def fit (self , inputs = None , wait = True , logs = True , job_name = None ):
@@ -342,6 +352,9 @@ def attach(cls, auto_ml_job_name, sagemaker_session=None):
342352 "AutoGenerateEndpointName" , False
343353 ),
344354 endpoint_name = auto_ml_job_desc .get ("ModelDeployConfig" , {}).get ("EndpointName" ),
355+ sample_weight_attribute_name = auto_ml_job_desc ["InputDataConfig" ][0 ].get (
356+ "SampleWeightAttributeName" , None
357+ ),
345358 )
346359 amlj .current_job_name = auto_ml_job_name
347360 amlj .latest_auto_ml_job = auto_ml_job_name # pylint: disable=W0201
@@ -867,6 +880,7 @@ def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True):
867880 auto_ml .target_attribute_name ,
868881 auto_ml .content_type ,
869882 auto_ml .s3_data_type ,
883+ auto_ml .sample_weight_attribute_name ,
870884 )
871885 output_config = _Job ._prepare_output_config (auto_ml .output_path , auto_ml .output_kms_key )
872886
@@ -932,6 +946,7 @@ def _format_inputs_to_input_config(
932946 target_attribute_name = None ,
933947 content_type = None ,
934948 s3_data_type = None ,
949+ sample_weight_attribute_name = None ,
935950 ):
936951 """Convert inputs to AutoML InputDataConfig.
937952
@@ -961,6 +976,8 @@ def _format_inputs_to_input_config(
961976 channel ["ContentType" ] = content_type
962977 if s3_data_type is not None :
963978 channel ["DataSource" ]["S3DataSource" ]["S3DataType" ] = s3_data_type
979+ if sample_weight_attribute_name is not None :
980+ channel ["SampleWeightAttributeName" ] = sample_weight_attribute_name
964981 channels .append (channel )
965982 elif isinstance (inputs , list ):
966983 for input_entry in inputs :
@@ -974,6 +991,8 @@ def _format_inputs_to_input_config(
974991 channel ["ContentType" ] = content_type
975992 if s3_data_type is not None :
976993 channel ["DataSource" ]["S3DataSource" ]["S3DataType" ] = s3_data_type
994+ if sample_weight_attribute_name is not None :
995+ channel ["SampleWeightAttributeName" ] = sample_weight_attribute_name
977996 channels .append (channel )
978997 else :
979998 msg = (
0 commit comments