3535 EDGE_PACKAGING_KMS_KEY_ID_PATH ,
3636 EDGE_PACKAGING_ROLE_ARN_PATH ,
3737 MODEL_CONTAINERS_PATH ,
38+ EDGE_PACKAGING_RESOURCE_KEY_PATH ,
3839 MODEL_VPC_CONFIG_PATH ,
3940 MODEL_ENABLE_NETWORK_ISOLATION_PATH ,
4041 MODEL_EXECUTION_ROLE_ARN_PATH ,
5051from sagemaker .predictor import PredictorBase
5152from sagemaker .serverless import ServerlessInferenceConfig
5253from sagemaker .transformer import Transformer
53- from sagemaker .jumpstart .utils import add_jumpstart_tags , get_jumpstart_base_name_if_jumpstart_model
54+ from sagemaker .jumpstart .utils import (
55+ add_jumpstart_tags ,
56+ get_jumpstart_base_name_if_jumpstart_model ,
57+ )
5458from sagemaker .utils import (
5559 unique_name_from_base ,
5660 update_container_with_inference_params ,
6367from sagemaker .workflow import is_pipeline_variable
6468from sagemaker .workflow .entities import PipelineVariable
6569from sagemaker .workflow .pipeline_context import runnable_by_pipeline , PipelineSession
66- from sagemaker .inference_recommender .inference_recommender_mixin import InferenceRecommenderMixin
70+ from sagemaker .inference_recommender .inference_recommender_mixin import (
71+ InferenceRecommenderMixin ,
72+ )
6773
6874LOGGER = logging .getLogger ("sagemaker" )
6975
7076NEO_ALLOWED_FRAMEWORKS = set (
7177 ["mxnet" , "tensorflow" , "keras" , "pytorch" , "onnx" , "xgboost" , "tflite" ]
7278)
7379
74- NEO_IOC_TARGET_DEVICES = ["ml_c4" , "ml_c5" , "ml_m4" , "ml_m5" , "ml_p2" , "ml_p3" , "ml_g4dn" ]
80+ NEO_IOC_TARGET_DEVICES = [
81+ "ml_c4" ,
82+ "ml_c5" ,
83+ "ml_m4" ,
84+ "ml_m5" ,
85+ "ml_p2" ,
86+ "ml_p3" ,
87+ "ml_g4dn" ,
88+ ]
7589
7690NEO_MULTIVERSION_UNSUPPORTED = [
7791 "imx8mplus" ,
@@ -300,7 +314,9 @@ def __init__(
300314 self ._base_name = None
301315 self .sagemaker_session = sagemaker_session
302316 self .role = resolve_value_from_config (
303- role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
317+ role ,
318+ MODEL_EXECUTION_ROLE_ARN_PATH ,
319+ sagemaker_session = self .sagemaker_session ,
304320 )
305321 self .vpc_config = resolve_value_from_config (
306322 vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
@@ -585,7 +601,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
585601 local_code = utils .get_config_value ("local.local_code" , self .sagemaker_session .config )
586602
587603 bucket , key_prefix = s3 .determine_bucket_and_prefix (
588- bucket = self .bucket , key_prefix = key_prefix , sagemaker_session = self .sagemaker_session
604+ bucket = self .bucket ,
605+ key_prefix = key_prefix ,
606+ sagemaker_session = self .sagemaker_session ,
589607 )
590608
591609 if (self .sagemaker_session .local_mode and local_code ) or self .entry_point is None :
@@ -633,7 +651,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
633651 else :
634652 repacked_model_data = "s3://" + "/" .join ([bucket , key_prefix , "model.tar.gz" ])
635653 self .uploaded_code = fw_utils .UploadedCode (
636- s3_prefix = repacked_model_data , script_name = os .path .basename (self .entry_point )
654+ s3_prefix = repacked_model_data ,
655+ script_name = os .path .basename (self .entry_point ),
637656 )
638657
639658 LOGGER .info (
@@ -693,7 +712,11 @@ def enable_network_isolation(self):
693712 return False if not self ._enable_network_isolation else self ._enable_network_isolation
694713
695714 def _create_sagemaker_model (
696- self , instance_type = None , accelerator_type = None , tags = None , serverless_inference_config = None
715+ self ,
716+ instance_type = None ,
717+ accelerator_type = None ,
718+ tags = None ,
719+ serverless_inference_config = None ,
697720 ):
698721 """Create a SageMaker Model Entity
699722
@@ -734,10 +757,14 @@ def _create_sagemaker_model(
734757 self ._init_sagemaker_session_if_does_not_exist (instance_type )
735758 # Depending on the instance type, a local session (or) a session is initialized.
736759 self .role = resolve_value_from_config (
737- self .role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
760+ self .role ,
761+ MODEL_EXECUTION_ROLE_ARN_PATH ,
762+ sagemaker_session = self .sagemaker_session ,
738763 )
739764 self .vpc_config = resolve_value_from_config (
740- self .vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
765+ self .vpc_config ,
766+ MODEL_VPC_CONFIG_PATH ,
767+ sagemaker_session = self .sagemaker_session ,
741768 )
742769 self ._enable_network_isolation = resolve_value_from_config (
743770 self ._enable_network_isolation ,
@@ -955,11 +982,16 @@ def package_for_edge(
955982 job_name = f"packaging{ self ._compilation_job_name [11 :]} "
956983 self ._init_sagemaker_session_if_does_not_exist (None )
957984 s3_kms_key = resolve_value_from_config (
958- s3_kms_key , EDGE_PACKAGING_KMS_KEY_ID_PATH , sagemaker_session = self .sagemaker_session
985+ s3_kms_key ,
986+ EDGE_PACKAGING_KMS_KEY_ID_PATH ,
987+ sagemaker_session = self .sagemaker_session ,
959988 )
960989 role = resolve_value_from_config (
961990 role , EDGE_PACKAGING_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
962991 )
992+ resource_key = resolve_value_from_config (
993+ resource_key , EDGE_PACKAGING_RESOURCE_KEY_PATH , sagemaker_session = self .sagemaker_session
994+ )
963995 if role is not None :
964996 role = self .sagemaker_session .expand_role (role )
965997 config = self ._edge_packaging_job_config (
@@ -1065,7 +1097,9 @@ def compile(
10651097
10661098 self ._init_sagemaker_session_if_does_not_exist (target_instance_family )
10671099 role = resolve_value_from_config (
1068- role , COMPILATION_JOB_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
1100+ role ,
1101+ COMPILATION_JOB_ROLE_ARN_PATH ,
1102+ sagemaker_session = self .sagemaker_session ,
10691103 )
10701104 if not role :
10711105 # Originally IAM role was a required parameter.
@@ -1232,10 +1266,14 @@ def deploy(
12321266 self ._init_sagemaker_session_if_does_not_exist (instance_type )
12331267 # Depending on the instance type, a local session (or) a session is initialized.
12341268 self .role = resolve_value_from_config (
1235- self .role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
1269+ self .role ,
1270+ MODEL_EXECUTION_ROLE_ARN_PATH ,
1271+ sagemaker_session = self .sagemaker_session ,
12361272 )
12371273 self .vpc_config = resolve_value_from_config (
1238- self .vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
1274+ self .vpc_config ,
1275+ MODEL_VPC_CONFIG_PATH ,
1276+ sagemaker_session = self .sagemaker_session ,
12391277 )
12401278 self ._enable_network_isolation = resolve_value_from_config (
12411279 self ._enable_network_isolation ,
@@ -1244,7 +1282,9 @@ def deploy(
12441282 )
12451283
12461284 tags = add_jumpstart_tags (
1247- tags = tags , inference_model_uri = self .model_data , inference_script_uri = self .source_dir
1285+ tags = tags ,
1286+ inference_model_uri = self .model_data ,
1287+ inference_script_uri = self .source_dir ,
12481288 )
12491289
12501290 if self .role is None :
@@ -1292,7 +1332,9 @@ def deploy(
12921332 compiled_model_suffix = None if is_serverless else "-" .join (instance_type .split ("." )[:- 1 ])
12931333 if self ._is_compiled_model and not is_serverless :
12941334 self ._ensure_base_name_if_needed (
1295- image_uri = self .image_uri , script_uri = self .source_dir , model_uri = self .model_data
1335+ image_uri = self .image_uri ,
1336+ script_uri = self .source_dir ,
1337+ model_uri = self .model_data ,
12961338 )
12971339 if self ._base_name is not None :
12981340 self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
@@ -1673,7 +1715,12 @@ class ModelPackage(Model):
16731715 """A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
16741716
16751717 def __init__ (
1676- self , role = None , model_data = None , algorithm_arn = None , model_package_arn = None , ** kwargs
1718+ self ,
1719+ role = None ,
1720+ model_data = None ,
1721+ algorithm_arn = None ,
1722+ model_package_arn = None ,
1723+ ** kwargs ,
16771724 ):
16781725 """Initialize a SageMaker ModelPackage.
16791726
0 commit comments