|
19 | 19 | from enum import Enum |
20 | 20 |
|
21 | 21 | import sagemaker |
22 | | -from sagemaker.amazon.amazon_estimator import RecordSet |
| 22 | +from sagemaker.amazon.amazon_estimator import RecordSet, AmazonAlgorithmEstimatorBase |
23 | 23 | from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa |
24 | 24 | from sagemaker.analytics import HyperparameterTuningJobAnalytics |
25 | 25 | from sagemaker.estimator import Framework |
@@ -358,7 +358,7 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim |
358 | 358 | estimator_cls, job_details["TrainingJobDefinition"] |
359 | 359 | ) |
360 | 360 | estimator = cls._prepare_estimator_from_job_description( |
361 | | - estimator_cls, job_details["TrainingJobDefinition"], sagemaker_session |
| 361 | + estimator_cls, job_details, sagemaker_session |
362 | 362 | ) |
363 | 363 | init_params = cls._prepare_init_params_from_job_description(job_details) |
364 | 364 |
|
@@ -497,16 +497,25 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details): |
497 | 497 | ) |
498 | 498 |
|
499 | 499 | @classmethod |
500 | | - def _prepare_estimator_from_job_description( |
501 | | - cls, estimator_cls, training_details, sagemaker_session |
502 | | - ): |
| 500 | + def _prepare_estimator_from_job_description(cls, estimator_cls, job_details, sagemaker_session): |
| 501 | + training_details = job_details["TrainingJobDefinition"] |
| 502 | + |
503 | 503 | # Swap name for static hyperparameters to what an estimator would expect |
504 | 504 | training_details["HyperParameters"] = training_details["StaticHyperParameters"] |
505 | 505 | del training_details["StaticHyperParameters"] |
506 | 506 |
|
507 | 507 | # Remove hyperparameter reserved by SageMaker for tuning jobs |
508 | 508 | del training_details["HyperParameters"]["_tuning_objective_metric"] |
509 | 509 |
|
| 510 | + # Add missing hyperparameters defined in the hyperparameter ranges, |
| 511 | + # as potentially required in the Amazon algorithm estimator's constructor |
| 512 | + if issubclass(estimator_cls, AmazonAlgorithmEstimatorBase): |
| 513 | + parameter_ranges = job_details["HyperParameterTuningJobConfig"]["ParameterRanges"] |
| 514 | + additional_hyperparameters = cls._extract_hyperparameters_from_parameter_ranges( |
| 515 | + parameter_ranges |
| 516 | + ) |
| 517 | + training_details["HyperParameters"].update(additional_hyperparameters) |
| 518 | + |
510 | 519 | # Add items expected by the estimator (but aren't needed otherwise) |
511 | 520 | training_details["TrainingJobName"] = "" |
512 | 521 | if "KmsKeyId" not in training_details["OutputDataConfig"]: |
@@ -559,6 +568,21 @@ def _prepare_parameter_ranges(cls, parameter_ranges): |
559 | 568 |
|
560 | 569 | return ranges |
561 | 570 |
|
| 571 | + @classmethod |
| 572 | + def _extract_hyperparameters_from_parameter_ranges(cls, parameter_ranges): |
| 573 | + hyperparameters = {} |
| 574 | + |
| 575 | + for parameter in parameter_ranges["CategoricalParameterRanges"]: |
| 576 | + hyperparameters[parameter["Name"]] = parameter["Values"][0] |
| 577 | + |
| 578 | + for parameter in parameter_ranges["ContinuousParameterRanges"]: |
| 579 | + hyperparameters[parameter["Name"]] = float(parameter["MinValue"]) |
| 580 | + |
| 581 | + for parameter in parameter_ranges["IntegerParameterRanges"]: |
| 582 | + hyperparameters[parameter["Name"]] = int(parameter["MinValue"]) |
| 583 | + |
| 584 | + return hyperparameters |
| 585 | + |
562 | 586 | def hyperparameter_ranges(self): |
563 | 587 | """Return the hyperparameter ranges in a dictionary to be used as part of a request for creating a |
564 | 588 | hyperparameter tuning job. |
|
0 commit comments