@@ -131,7 +131,7 @@ class Model(ModelBase, InferenceRecommenderMixin):
131131 def __init__ (
132132 self ,
133133 image_uri : Union [str , PipelineVariable ],
134- model_data : Optional [Union [str , PipelineVariable ]] = None ,
134+ model_data : Optional [Union [str , PipelineVariable , dict ]] = None ,
135135 role : Optional [str ] = None ,
136136 predictor_cls : Optional [callable ] = None ,
137137 env : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
@@ -152,8 +152,8 @@ def __init__(
152152
153153 Args:
154154 image_uri (str or PipelineVariable): A Docker image URI.
155- model_data (str or PipelineVariable): The S3 location of a SageMaker
156- model data ``.tar.gz`` file (default: None).
155+ model_data (str or PipelineVariable or dict ): Location
156+ of SageMaker model data (default: None).
157157 role (str): An AWS IAM role (either name or full ARN). The Amazon
158158 SageMaker training jobs and APIs that create Amazon SageMaker
159159 endpoints use this role to access training data and model
@@ -455,6 +455,11 @@ def register(
455455 """
456456 if self .model_data is None :
457457 raise ValueError ("SageMaker Model Package cannot be created without model data." )
458+ if isinstance (self .model_data , dict ):
459+ raise ValueError (
460+ "SageMaker Model Package currently cannot be created with ModelDataSource."
461+ )
462+
458463 if image_uri is not None :
459464 self .image_uri = image_uri
460465
@@ -600,6 +605,7 @@ def prepare_container_def(
600605 )
601606 self ._upload_code (deploy_key_prefix , repack = is_repack )
602607 deploy_env .update (self ._script_mode_env_vars ())
608+
603609 return sagemaker .container_def (
604610 self .image_uri ,
605611 self .repacked_model_data or self .model_data ,
@@ -639,6 +645,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
639645 )
640646
641647 if repack and self .model_data is not None and self .entry_point is not None :
648+ if isinstance (self .model_data , dict ):
649+ logging .warning ("ModelDataSource currently doesn't support model repacking" )
650+ return
642651 if is_pipeline_variable (self .model_data ):
643652 # model is not yet there, defer repacking to later during pipeline execution
644653 if not isinstance (self .sagemaker_session , PipelineSession ):
@@ -765,10 +774,16 @@ def _create_sagemaker_model(
765774 # _base_name, model_name are not needed under PipelineSession.
766775 # the model_data may be Pipeline variable
767776 # which may break the _base_name generation
777+ model_uri = None
778+ if isinstance (self .model_data , (str , PipelineVariable )):
779+ model_uri = self .model_data
780+ elif isinstance (self .model_data , dict ):
781+ model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
782+
768783 self ._ensure_base_name_if_needed (
769784 image_uri = container_def ["Image" ],
770785 script_uri = self .source_dir ,
771- model_uri = self . model_data ,
786+ model_uri = model_uri ,
772787 )
773788 self ._set_model_name_if_needed ()
774789
@@ -1110,6 +1125,8 @@ def compile(
11101125 raise ValueError ("You must provide a compilation job name" )
11111126 if self .model_data is None :
11121127 raise ValueError ("You must provide an S3 path to the compressed model artifacts." )
1128+ if isinstance (self .model_data , dict ):
1129+ raise ValueError ("Compiling model data from ModelDataSource is currently not supported" )
11131130
11141131 framework_version = framework_version or self ._get_framework_version ()
11151132
@@ -1301,7 +1318,7 @@ def deploy(
13011318
13021319 tags = add_jumpstart_tags (
13031320 tags = tags ,
1304- inference_model_uri = self .model_data ,
1321+ inference_model_uri = self .model_data if isinstance ( self . model_data , str ) else None ,
13051322 inference_script_uri = self .source_dir ,
13061323 )
13071324
@@ -1545,7 +1562,7 @@ class FrameworkModel(Model):
15451562
15461563 def __init__ (
15471564 self ,
1548- model_data : Union [str , PipelineVariable ],
1565+ model_data : Union [str , PipelineVariable , dict ],
15491566 image_uri : Union [str , PipelineVariable ],
15501567 role : Optional [str ] = None ,
15511568 entry_point : Optional [str ] = None ,
@@ -1563,8 +1580,8 @@ def __init__(
15631580 """Initialize a ``FrameworkModel``.
15641581
15651582 Args:
1566- model_data (str or PipelineVariable): The S3 location of a SageMaker
1567- model data ``.tar.gz`` file .
1583+ model_data (str or PipelineVariable or dict ): The S3 location of
1584+ SageMaker model data.
15681585 image_uri (str or PipelineVariable): A Docker image URI.
15691586 role (str): An IAM role name or ARN for SageMaker to access AWS
15701587 resources on your behalf.
@@ -1758,6 +1775,11 @@ def __init__(
17581775 ``model_data`` is not required.
17591776 **kwargs: Additional kwargs passed to the Model constructor.
17601777 """
1778+ if isinstance (model_data , dict ):
1779+ raise ValueError (
1780+ "Creating ModelPackage with ModelDataSource is currently not supported"
1781+ )
1782+
17611783 super (ModelPackage , self ).__init__ (
17621784 role = role , model_data = model_data , image_uri = None , ** kwargs
17631785 )
0 commit comments