Skip to content

Commit e981daa

Browse files
authored
fix: SDK defaults for volume size, JS Estimator image uri region, Predictor str method (#3870)
1 parent 9ae348e commit e981daa

File tree

20 files changed

+784
-162
lines changed

20 files changed

+784
-162
lines changed

src/sagemaker/base_predictor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
NumpySerializer,
4848
)
4949
from sagemaker.session import production_variant, Session
50-
from sagemaker.utils import name_from_base
50+
from sagemaker.utils import name_from_base, stringify_object
5151

5252
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5353

@@ -75,6 +75,10 @@ def content_type(self) -> str:
7575
def accept(self) -> Tuple[str]:
7676
"""The content type(s) that are expected from the inference server."""
7777

78+
def __str__(self) -> str:
79+
"""Overriding str(*) method to make more human-readable."""
80+
return stringify_object(self)
81+
7882

7983
class Predictor(PredictorBase):
8084
"""Make prediction requests to an Amazon SageMaker endpoint."""

src/sagemaker/estimator.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
)
6464
from sagemaker.inputs import TrainingInput, FileSystemInput
6565
from sagemaker.instance_group import InstanceGroup
66+
from sagemaker.utils import instance_supports_kms
6667
from sagemaker.job import _Job
6768
from sagemaker.jumpstart.utils import (
6869
add_jumpstart_tags,
@@ -95,6 +96,7 @@
9596
)
9697
from sagemaker.workflow import is_pipeline_variable
9798
from sagemaker.workflow.entities import PipelineVariable
99+
from sagemaker.workflow.parameters import ParameterString
98100
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
99101

100102
logger = logging.getLogger(__name__)
@@ -599,10 +601,33 @@ def __init__(
599601
self.output_kms_key = resolve_value_from_config(
600602
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
601603
)
602-
self.volume_kms_key = resolve_value_from_config(
603-
volume_kms_key,
604-
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
605-
sagemaker_session=self.sagemaker_session,
604+
if instance_type is None or isinstance(instance_type, str):
605+
instance_type_for_volume_kms = instance_type
606+
elif isinstance(instance_type, ParameterString):
607+
instance_type_for_volume_kms = instance_type.default_value
608+
else:
609+
raise ValueError(f"Bad value for instance type: '{instance_type}'")
610+
611+
# KMS can only be attached to supported instances
612+
use_volume_kms_config = (
613+
(instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms))
614+
or instance_groups is not None
615+
and any(
616+
[
617+
instance_supports_kms(instance_group.instance_type)
618+
for instance_group in instance_groups
619+
]
620+
)
621+
)
622+
623+
self.volume_kms_key = (
624+
resolve_value_from_config(
625+
volume_kms_key,
626+
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
627+
sagemaker_session=self.sagemaker_session,
628+
)
629+
if use_volume_kms_config
630+
else volume_kms_key
606631
)
607632

608633
# VPC configurations

src/sagemaker/instance_types.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -119,22 +119,3 @@ def retrieve(
119119
tolerate_vulnerable_model,
120120
tolerate_deprecated_model,
121121
)
122-
123-
124-
def volume_size_supported(instance_type: str) -> bool:
125-
"""Returns True if SageMaker allows volume_size to be used for the instance type.
126-
127-
Raises:
128-
ValueError: If the instance type is improperly formatted.
129-
"""
130-
try:
131-
parts: List[str] = instance_type.split(".")
132-
if len(parts) != 3 or parts[0] != "ml":
133-
raise ValueError("Instance type must have 2 periods and start with 'ml'.")
134-
135-
# Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + g5
136-
# does not support attaching an EBS volume.
137-
family = parts[1]
138-
return "d" not in family and not family.startswith("g5")
139-
except Exception as e:
140-
raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}")

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
from copy import deepcopy
1616
from typing import Optional
17-
from sagemaker.instance_types import volume_size_supported
17+
from sagemaker.utils import volume_size_supported
1818
from sagemaker.jumpstart.constants import (
1919
JUMPSTART_DEFAULT_REGION_NAME,
2020
)

src/sagemaker/jumpstart/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
is_valid_model_id,
3737
resolve_model_intelligent_default_field,
3838
)
39-
from sagemaker.jumpstart.utils import stringify_object
39+
from sagemaker.utils import stringify_object
4040
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
4141
from sagemaker.predictor import PredictorBase
4242

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
425425
"""Sets image uri in kwargs based on default or override, returns full kwargs."""
426426

427427
kwargs.image_uri = kwargs.image_uri or image_uris.retrieve(
428-
region=None,
428+
region=kwargs.region,
429429
framework=None,
430430
image_scope=JumpStartScriptScope.TRAINING,
431431
model_id=kwargs.model_id,

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
get_init_kwargs,
2929
)
3030
from sagemaker.jumpstart.utils import is_valid_model_id
31-
from sagemaker.jumpstart.utils import stringify_object
31+
from sagemaker.utils import stringify_object
3232
from sagemaker.model import Model
3333
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3434
from sagemaker.predictor import PredictorBase

src/sagemaker/jumpstart/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -530,12 +530,6 @@ def resolve_estimator_intelligent_default_field(
530530
return field_val
531531

532532

533-
def stringify_object(obj: Any) -> str:
534-
"""Returns string representation of object, returning only non-None fields."""
535-
non_none_atts = {key: value for key, value in obj.__dict__.items() if value is not None}
536-
return f"{type(obj).__name__}: {str(non_none_atts)}"
537-
538-
539533
def is_valid_model_id(
540534
model_id: Optional[str],
541535
region: Optional[str] = None,

src/sagemaker/pipeline.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.transformer import Transformer
3535
from sagemaker.workflow.entities import PipelineVariable
3636
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
37+
from sagemaker.utils import instance_supports_kms
3738

3839

3940
class PipelineModel(object):
@@ -235,8 +236,12 @@ def deploy(
235236
container_startup_health_check_timeout=container_startup_health_check_timeout,
236237
)
237238
self.endpoint_name = endpoint_name or self.name
238-
kms_key = resolve_value_from_config(
239-
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
239+
kms_key = (
240+
resolve_value_from_config(
241+
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
242+
)
243+
if instance_supports_kms(instance_type)
244+
else kms_key
240245
)
241246

242247
data_capture_config_dict = None

src/sagemaker/session.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import botocore.config
3030
from botocore.exceptions import ClientError
3131
import six
32+
from sagemaker.utils import instance_supports_kms
3233

3334
import sagemaker.logs
3435
from sagemaker import vpc_utils, s3_utils
@@ -811,9 +812,17 @@ def train( # noqa: C901
811812
inferred_output_config = update_nested_dictionary_with_values_from_config(
812813
output_config, TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, sagemaker_session=self
813814
)
815+
customer_supplied_kms_key = "VolumeKmsKeyId" in resource_config
814816
inferred_resource_config = update_nested_dictionary_with_values_from_config(
815817
resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH, sagemaker_session=self
816818
)
819+
if (
820+
not customer_supplied_kms_key
821+
and "InstanceType" in inferred_resource_config
822+
and not instance_supports_kms(inferred_resource_config["InstanceType"])
823+
and "VolumeKmsKeyId" in inferred_resource_config
824+
):
825+
del inferred_resource_config["VolumeKmsKeyId"]
817826
train_request = self._get_train_request(
818827
input_mode=input_mode,
819828
input_config=input_config,
@@ -3750,8 +3759,12 @@ def create_endpoint_config(
37503759
)
37513760
if tags is not None:
37523761
request["Tags"] = tags
3753-
kms_key = resolve_value_from_config(
3754-
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
3762+
kms_key = (
3763+
resolve_value_from_config(
3764+
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
3765+
)
3766+
if instance_supports_kms(instance_type)
3767+
else kms_key
37553768
)
37563769
if kms_key is not None:
37573770
request["KmsKeyId"] = kms_key
@@ -3844,7 +3857,16 @@ def create_endpoint_config_from_existing(
38443857

38453858
if new_kms_key is not None or existing_endpoint_config_desc.get("KmsKeyId") is not None:
38463859
request["KmsKeyId"] = new_kms_key or existing_endpoint_config_desc.get("KmsKeyId")
3847-
if KMS_KEY_ID not in request:
3860+
3861+
supports_kms = any(
3862+
[
3863+
instance_supports_kms(production_variant["InstanceType"])
3864+
for production_variant in production_variants
3865+
if "InstanceType" in production_variant
3866+
]
3867+
)
3868+
3869+
if KMS_KEY_ID not in request and supports_kms:
38483870
kms_key_from_config = resolve_value_from_config(
38493871
config_path=ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
38503872
)
@@ -4465,15 +4487,28 @@ def endpoint_from_production_variants(
44654487
Returns:
44664488
str: The name of the created ``Endpoint``.
44674489
"""
4490+
4491+
supports_kms = any(
4492+
[
4493+
instance_supports_kms(production_variant["InstanceType"])
4494+
for production_variant in production_variants
4495+
if "InstanceType" in production_variant
4496+
]
4497+
)
4498+
44684499
update_list_of_dicts_with_values_from_config(
44694500
production_variants,
44704501
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
44714502
required_key_paths=["CoreDumpConfig.DestinationS3Uri"],
44724503
sagemaker_session=self,
44734504
)
44744505
config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
4475-
kms_key = resolve_value_from_config(
4476-
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
4506+
kms_key = (
4507+
resolve_value_from_config(
4508+
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
4509+
)
4510+
if supports_kms
4511+
else kms_key
44774512
)
44784513
tags = _append_project_tags(tags)
44794514
tags = self._append_sagemaker_config_tags(

0 commit comments

Comments
 (0)