 import os

 import sagemaker
-from sagemaker import job, utils, model
+from sagemaker import job, model, utils
 from sagemaker.amazon import amazon_estimator


@@ -48,14 +48,19 @@ def prepare_framework(estimator, s3_operations):
     estimator._hyperparameters[model.SAGEMAKER_REGION_PARAM_NAME] = estimator.sagemaker_session.boto_region_name


-def prepare_amazon_algorithm_estimator(estimator, inputs):
+def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
     """Set up amazon algorithm estimator, adding the required `feature_dim` hyperparameter from training data.

     Args:
         estimator (sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
             An estimator for a built-in Amazon algorithm to get information from and update.
-        inputs (single or list of sagemaker.amazon.amazon_estimator.RecordSet):
-            The training data, must be in RecordSet format.
+        inputs: The training data.
+            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                Amazon :class:~`Record` objects serialized and stored in S3.
+                For use with an estimator for an Amazon algorithm.
+            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
+                a different channel of training data.
     """
     if isinstance(inputs, list):
         for record in inputs:
@@ -66,22 +71,39 @@ def prepare_amazon_algorithm_estimator(estimator, inputs):
         estimator.feature_dim = inputs.feature_dim
     else:
         raise TypeError('Training data must be represented in RecordSet or list of RecordSets')
+    estimator.mini_batch_size = mini_batch_size


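For reviewers, a minimal sketch of how the new mini_batch_size argument reaches this helper from the public training_config() added further down, assuming the SDK's built-in PCA estimator and its record_set() uploader; the role name and data are hypothetical placeholders:

    import numpy as np

    import sagemaker
    from sagemaker.workflow import airflow as sm_airflow

    # Hypothetical built-in Amazon algorithm estimator; 'SageMakerRole' is a placeholder IAM role.
    pca = sagemaker.PCA(role='SageMakerRole',
                        train_instance_count=1,
                        train_instance_type='ml.c4.xlarge',
                        num_components=5)

    # record_set() serializes the array and uploads it to S3 as a RecordSet.
    train_data = np.random.rand(100, 50).astype('float32')
    records = pca.record_set(train_data)

    # mini_batch_size is forwarded to prepare_amazon_algorithm_estimator() above.
    config = sm_airflow.training_config(pca, inputs=records, mini_batch_size=20)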
-def training_config(estimator, inputs=None, job_name=None):  # noqa: C901 - suppress complexity warning for this method
-    """Export Airflow training config from an estimator
+def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
+    """Export Airflow base training config from an estimator

     Args:
-        estimator (sagemaker.estimator.EstimatroBase):
+        estimator (sagemaker.estimator.EstimatorBase):
             The estimator to export training config from. Can be a BYO estimator,
             Framework estimator or Amazon algorithm estimator.
-        inputs (str, dict, single or list of sagemaker.amazon.amazon_estimator.RecordSet):
-            The training data.
+        inputs: Information about the training data. Please refer to the ``fit()`` method of
+            the associated estimator, as this can take any of the following forms:
+
+            * (str) - The S3 location where training data is saved.
+            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
+                training data, you can specify a dict mapping channel names
+                to strings or :func:`~sagemaker.session.s3_input` objects.
+            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can provide
+                additional information about the training dataset. See :func:`sagemaker.session.s3_input`
+                for full details.
+            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                Amazon :class:~`Record` objects serialized and stored in S3.
+                For use with an estimator for an Amazon algorithm.
+            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
+                a different channel of training data.
+
         job_name (str): Specify a training job name if needed.
+        mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
+            Amazon algorithm. For other estimators, batch size should be specified in the estimator.

-    Returns:
-        A dict of training config that can be directly used by SageMakerTrainingOperator
-        in Airflow.
+    Returns (dict):
+        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
     """
     default_bucket = estimator.sagemaker_session.default_bucket()
     s3_operations = {}
@@ -99,8 +121,7 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
         prepare_framework(estimator, s3_operations)

     elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
-        prepare_amazon_algorithm_estimator(estimator, inputs)
-
+        prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size)
     job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False)

     train_config = {
@@ -109,7 +130,6 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
             'TrainingInputMode': estimator.input_mode
         },
         'OutputDataConfig': job_config['output_config'],
-        'TrainingJobName': estimator._current_job_name,
         'StoppingCondition': job_config['stop_condition'],
         'ResourceConfig': job_config['resource_config'],
         'RoleArn': job_config['role'],
@@ -127,10 +147,125 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
     if hyperparameters and len(hyperparameters) > 0:
         train_config['HyperParameters'] = hyperparameters

-    if estimator.tags is not None:
-        train_config['Tags'] = estimator.tags
-
     if s3_operations:
         train_config['S3Operations'] = s3_operations

     return train_config
+
+
+def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
+    """Export Airflow training config from an estimator
+
+    Args:
+        estimator (sagemaker.estimator.EstimatorBase):
+            The estimator to export training config from. Can be a BYO estimator,
+            Framework estimator or Amazon algorithm estimator.
+        inputs: Information about the training data. Please refer to the ``fit()`` method of
+            the associated estimator, as this can take any of the following forms:
+
+            * (str) - The S3 location where training data is saved.
+            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
+                training data, you can specify a dict mapping channel names
+                to strings or :func:`~sagemaker.session.s3_input` objects.
+            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can provide
+                additional information about the training dataset. See :func:`sagemaker.session.s3_input`
+                for full details.
+            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                Amazon :class:~`Record` objects serialized and stored in S3.
+                For use with an estimator for an Amazon algorithm.
+            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
+                a different channel of training data.
+
+        job_name (str): Specify a training job name if needed.
+        mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
+            Amazon algorithm. For other estimators, batch size should be specified in the estimator.
+
+    Returns (dict):
+        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
+    """
+
+    train_config = training_base_config(estimator, inputs, job_name, mini_batch_size)
+
+    train_config['TrainingJobName'] = estimator._current_job_name
+
+    if estimator.tags is not None:
+        train_config['Tags'] = estimator.tags
+
+    return train_config
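An end-to-end usage sketch for training_config(), assuming Airflow's contrib SageMakerTrainingOperator is available; the image, role, and bucket names are hypothetical placeholders:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

    from sagemaker.estimator import Estimator
    from sagemaker.workflow import airflow as sm_airflow

    # Hypothetical BYO training image, IAM role, and S3 locations.
    estimator = Estimator(
        image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest',
        role='SageMakerRole',
        train_instance_count=1,
        train_instance_type='ml.m4.xlarge',
        output_path='s3://my-bucket/output')

    # Export the operator-ready dict; TrainingJobName and Tags are added by training_config().
    train_config = sm_airflow.training_config(estimator, inputs='s3://my-bucket/train')

    dag = DAG('sagemaker_example', start_date=datetime(2018, 1, 1), schedule_interval=None)

    train_op = SageMakerTrainingOperator(
        task_id='sagemaker_training',
        config=train_config,
        wait_for_completion=True,
        dag=dag)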
+
+
+def tuning_config(tuner, inputs, job_name=None):
+    """Export Airflow tuning config from a HyperparameterTuner
+
+    Args:
+        tuner (sagemaker.tuner.HyperparameterTuner): The tuner to export tuning config from.
+        inputs: Information about the training data. Please refer to the ``fit()`` method of
+            the associated estimator in the tuner, as this can take any of the following forms:
+
+            * (str) - The S3 location where training data is saved.
+            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
+                training data, you can specify a dict mapping channel names
+                to strings or :func:`~sagemaker.session.s3_input` objects.
+            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can provide
+                additional information about the training dataset. See :func:`sagemaker.session.s3_input`
+                for full details.
+            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                Amazon :class:~`Record` objects serialized and stored in S3.
+                For use with an estimator for an Amazon algorithm.
+            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
+                a different channel of training data.
+
+        job_name (str): Specify a tuning job name if needed.
+
+    Returns (dict):
+        Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
+    """
+    train_config = training_base_config(tuner.estimator, inputs)
+    hyperparameters = train_config.pop('HyperParameters', None)
+    s3_operations = train_config.pop('S3Operations', None)
+
+    if hyperparameters and len(hyperparameters) > 0:
+        tuner.static_hyperparameters = \
+            {utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items()}
+
+    if job_name is not None:
+        tuner._current_job_name = job_name
+    else:
+        base_name = tuner.base_tuning_job_name or utils.base_name_from_image(tuner.estimator.train_image())
+        tuner._current_job_name = utils.airflow_name_from_base(base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True)
+
+    for hyperparameter_name in tuner._hyperparameter_ranges.keys():
+        tuner.static_hyperparameters.pop(hyperparameter_name, None)
+
+    train_config['StaticHyperParameters'] = tuner.static_hyperparameters
+
+    tune_config = {
+        'HyperParameterTuningJobName': tuner._current_job_name,
+        'HyperParameterTuningJobConfig': {
+            'Strategy': tuner.strategy,
+            'HyperParameterTuningJobObjective': {
+                'Type': tuner.objective_type,
+                'MetricName': tuner.objective_metric_name,
+            },
+            'ResourceLimits': {
+                'MaxNumberOfTrainingJobs': tuner.max_jobs,
+                'MaxParallelTrainingJobs': tuner.max_parallel_jobs,
+            },
+            'ParameterRanges': tuner.hyperparameter_ranges(),
+        },
+        'TrainingJobDefinition': train_config
+    }
+
+    if tuner.metric_definitions is not None:
+        tune_config['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = \
+            tuner.metric_definitions
+
+    if tuner.tags is not None:
+        tune_config['Tags'] = tuner.tags
+
+    if s3_operations is not None:
+        tune_config['S3Operations'] = s3_operations
+
+    return tune_config
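A companion sketch for the new tuning_config(), pairing a HyperparameterTuner with Airflow's contrib SageMakerTuningOperator (available in newer Airflow releases); the objective metric, regex, and S3 prefix are hypothetical, and estimator/dag are reused from the training sketch above:

    from airflow.contrib.operators.sagemaker_tuning_operator import SageMakerTuningOperator

    from sagemaker.tuner import ContinuousParameter, HyperparameterTuner
    from sagemaker.workflow import airflow as sm_airflow

    # Hypothetical objective metric and a single tunable hyperparameter range.
    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name='validation:accuracy',
        hyperparameter_ranges={'learning_rate': ContinuousParameter(0.01, 0.2)},
        metric_definitions=[{'Name': 'validation:accuracy',
                             'Regex': 'accuracy = ([0-9\\.]+)'}],
        max_jobs=4,
        max_parallel_jobs=2)

    # tuning_config() reuses training_base_config() for the TrainingJobDefinition.
    tune_config = sm_airflow.tuning_config(tuner, inputs='s3://my-bucket/train')

    tune_op = SageMakerTuningOperator(
        task_id='sagemaker_tuning',
        config=tune_config,
        wait_for_completion=True,
        dag=dag)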