|
29 | 29 | import botocore.config |
30 | 30 | from botocore.exceptions import ClientError |
31 | 31 | import six |
| 32 | +from sagemaker.utils import instance_supports_kms |
32 | 33 |
|
33 | 34 | import sagemaker.logs |
34 | 35 | from sagemaker import vpc_utils, s3_utils |
@@ -811,9 +812,17 @@ def train( # noqa: C901 |
811 | 812 | inferred_output_config = update_nested_dictionary_with_values_from_config( |
812 | 813 | output_config, TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, sagemaker_session=self |
813 | 814 | ) |
| 815 | + customer_supplied_kms_key = "VolumeKmsKeyId" in resource_config |
814 | 816 | inferred_resource_config = update_nested_dictionary_with_values_from_config( |
815 | 817 | resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH, sagemaker_session=self |
816 | 818 | ) |
| 819 | + if ( |
| 820 | + not customer_supplied_kms_key |
| 821 | + and "InstanceType" in inferred_resource_config |
| 822 | + and not instance_supports_kms(inferred_resource_config["InstanceType"]) |
| 823 | + and "VolumeKmsKeyId" in inferred_resource_config |
| 824 | + ): |
| 825 | + del inferred_resource_config["VolumeKmsKeyId"] |
817 | 826 | train_request = self._get_train_request( |
818 | 827 | input_mode=input_mode, |
819 | 828 | input_config=input_config, |
@@ -3750,8 +3759,12 @@ def create_endpoint_config( |
3750 | 3759 | ) |
3751 | 3760 | if tags is not None: |
3752 | 3761 | request["Tags"] = tags |
3753 | | - kms_key = resolve_value_from_config( |
3754 | | - kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self |
| 3762 | + kms_key = ( |
| 3763 | + resolve_value_from_config( |
| 3764 | + kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self |
| 3765 | + ) |
| 3766 | + if instance_supports_kms(instance_type) |
| 3767 | + else kms_key |
3755 | 3768 | ) |
3756 | 3769 | if kms_key is not None: |
3757 | 3770 | request["KmsKeyId"] = kms_key |
@@ -3844,7 +3857,16 @@ def create_endpoint_config_from_existing( |
3844 | 3857 |
|
3845 | 3858 | if new_kms_key is not None or existing_endpoint_config_desc.get("KmsKeyId") is not None: |
3846 | 3859 | request["KmsKeyId"] = new_kms_key or existing_endpoint_config_desc.get("KmsKeyId") |
3847 | | - if KMS_KEY_ID not in request: |
| 3860 | + |
| 3861 | + supports_kms = any( |
| 3862 | + [ |
| 3863 | + instance_supports_kms(production_variant["InstanceType"]) |
| 3864 | + for production_variant in production_variants |
| 3865 | + if "InstanceType" in production_variant |
| 3866 | + ] |
| 3867 | + ) |
| 3868 | + |
| 3869 | + if KMS_KEY_ID not in request and supports_kms: |
3848 | 3870 | kms_key_from_config = resolve_value_from_config( |
3849 | 3871 | config_path=ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self |
3850 | 3872 | ) |
@@ -4465,15 +4487,28 @@ def endpoint_from_production_variants( |
4465 | 4487 | Returns: |
4466 | 4488 | str: The name of the created ``Endpoint``. |
4467 | 4489 | """ |
| 4490 | + |
| 4491 | + supports_kms = any( |
| 4492 | + [ |
| 4493 | + instance_supports_kms(production_variant["InstanceType"]) |
| 4494 | + for production_variant in production_variants |
| 4495 | + if "InstanceType" in production_variant |
| 4496 | + ] |
| 4497 | + ) |
| 4498 | + |
4468 | 4499 | update_list_of_dicts_with_values_from_config( |
4469 | 4500 | production_variants, |
4470 | 4501 | ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, |
4471 | 4502 | required_key_paths=["CoreDumpConfig.DestinationS3Uri"], |
4472 | 4503 | sagemaker_session=self, |
4473 | 4504 | ) |
4474 | 4505 | config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} |
4475 | | - kms_key = resolve_value_from_config( |
4476 | | - kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self |
| 4506 | + kms_key = ( |
| 4507 | + resolve_value_from_config( |
| 4508 | + kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self |
| 4509 | + ) |
| 4510 | + if supports_kms |
| 4511 | + else kms_key |
4477 | 4512 | ) |
4478 | 4513 | tags = _append_project_tags(tags) |
4479 | 4514 | tags = self._append_sagemaker_config_tags( |
|
0 commit comments