7777)
7878from sagemaker .compute_resource_requirements .resource_requirements import ResourceRequirements
7979from sagemaker .enums import EndpointType
80- from sagemaker .session import get_add_model_package_inference_args
80+ from sagemaker .session import (
81+ get_add_model_package_inference_args ,
82+ get_update_model_package_inference_args ,
83+ )
8184
8285# Setting LOGGER for backward compatibility, in case users import it...
8386logger = LOGGER = logging .getLogger ("sagemaker" )
@@ -423,6 +426,7 @@ def register(
423426 nearest_model_name : Optional [Union [str , PipelineVariable ]] = None ,
424427 data_input_configuration : Optional [Union [str , PipelineVariable ]] = None ,
425428 skip_model_validation : Optional [Union [str , PipelineVariable ]] = None ,
429+ source_uri : Optional [Union [str , PipelineVariable ]] = None ,
426430 ):
427431 """Creates a model package for creating SageMaker models or listing on Marketplace.
428432
@@ -472,17 +476,14 @@ def register(
472476 (default: None).
473477 skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
474478 validation. Values can be "All" or "None" (default: None).
479+ source_uri (str or PipelineVariable): The URI of the source for the model package
480+ (default: None).
475481
476482 Returns:
477483 A `sagemaker.model.ModelPackage` instance or pipeline step arguments
478484 in case the Model instance is built with
479485 :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
480486 """
481- if isinstance (self .model_data , dict ):
482- raise ValueError (
483- "SageMaker Model Package currently cannot be created with ModelDataSource."
484- )
485-
486487 if content_types is not None :
487488 self .content_types = content_types
488489
@@ -513,6 +514,12 @@ def register(
513514 "Image" : self .image_uri ,
514515 }
515516
517+ if isinstance (self .model_data , dict ):
518+ raise ValueError (
519+ "Un-versioned SageMaker Model Package currently cannot be "
520+ "created with ModelDataSource."
521+ )
522+
516523 if self .model_data is not None :
517524 container_def ["ModelDataUrl" ] = self .model_data
518525
@@ -536,6 +543,7 @@ def register(
536543 sample_payload_url = sample_payload_url ,
537544 task = task ,
538545 skip_model_validation = skip_model_validation ,
546+ source_uri = source_uri ,
539547 )
540548 model_package = self .sagemaker_session .create_model_package_from_containers (
541549 ** model_pkg_args
@@ -2040,8 +2048,9 @@ def __init__(
20402048 endpoints use this role to access training data and model
20412049 artifacts. After the endpoint is created, the inference code
20422050 might use the IAM role, if it needs to access an AWS resource.
2043- model_data (str): The S3 location of a SageMaker model data
2044- ``.tar.gz`` file. Must be provided if algorithm_arn is provided.
2051+ model_data (str or dict[str, Any]): The S3 location of a SageMaker model data
2052+ ``.tar.gz`` file or a dictionary representing a ``ModelDataSource``
2053+ object. Must be provided if algorithm_arn is provided.
20452054 algorithm_arn (str): algorithm arn used to train the model, can be
20462055 just the name if your account owns the algorithm. Must also
20472056 provide ``model_data``.
@@ -2050,11 +2059,6 @@ def __init__(
20502059 ``model_data`` is not required.
20512060 **kwargs: Additional kwargs passed to the Model constructor.
20522061 """
2053- if isinstance (model_data , dict ):
2054- raise ValueError (
2055- "Creating ModelPackage with ModelDataSource is currently not supported"
2056- )
2057-
20582062 super (ModelPackage , self ).__init__ (
20592063 role = role , model_data = model_data , image_uri = None , ** kwargs
20602064 )
@@ -2222,6 +2226,74 @@ def update_customer_metadata(self, customer_metadata_properties: Dict[str, str])
22222226 sagemaker_session = self .sagemaker_session or sagemaker .Session ()
22232227 sagemaker_session .sagemaker_client .update_model_package (** update_metadata_args )
22242228
2229+ def update_inference_specification (
2230+ self ,
2231+ containers : Dict = None ,
2232+ image_uris : List [str ] = None ,
2233+ content_types : List [str ] = None ,
2234+ response_types : List [str ] = None ,
2235+ inference_instances : List [str ] = None ,
2236+ transform_instances : List [str ] = None ,
2237+ ):
2238+ """Inference specification to be set for the model package
2239+
2240+ Args:
2241+ containers (dict): The Amazon ECR registry path of the Docker image
2242+ that contains the inference code.
2243+ image_uris (List[str]): The ECR path where inference code is stored.
2244+ content_types (list[str]): The supported MIME types
2245+ for the input data.
2246+ response_types (list[str]): The supported MIME types
2247+ for the output data.
2248+ inference_instances (list[str]): A list of the instance
2249+ types that are used to generate inferences in real-time (default: None).
2250+ transform_instances (list[str]): A list of the instance
2251+ types on which a transformation job can be run or on which an endpoint can be
2252+ deployed (default: None).
2253+
2254+ """
2255+ sagemaker_session = self .sagemaker_session or sagemaker .Session ()
2256+ if (containers is not None ) ^ (image_uris is None ):
2257+ raise ValueError ("Should have either containers or image_uris for inference." )
2258+ container_def = []
2259+ if image_uris :
2260+ for uri in image_uris :
2261+ container_def .append (
2262+ {
2263+ "Image" : uri ,
2264+ }
2265+ )
2266+ else :
2267+ container_def = containers
2268+
2269+ model_package_update_args = get_update_model_package_inference_args (
2270+ model_package_arn = self .model_package_arn ,
2271+ containers = container_def ,
2272+ content_types = content_types ,
2273+ response_types = response_types ,
2274+ inference_instances = inference_instances ,
2275+ transform_instances = transform_instances ,
2276+ )
2277+
2278+ sagemaker_session .sagemaker_client .update_model_package (** model_package_update_args )
2279+
2280+ def update_source_uri (
2281+ self ,
2282+ source_uri : str ,
2283+ ):
2284+ """Source uri to be set for the model package
2285+
2286+ Args:
2287+ source_uri (str): The URI of the source for the model package.
2288+
2289+ """
2290+ update_source_uri_args = {
2291+ "ModelPackageArn" : self .model_package_arn ,
2292+ "SourceUri" : source_uri ,
2293+ }
2294+ sagemaker_session = self .sagemaker_session or sagemaker .Session ()
2295+ sagemaker_session .sagemaker_client .update_model_package (** update_source_uri_args )
2296+
22252297 def remove_customer_metadata_properties (
22262298 self , customer_metadata_properties_to_remove : List [str ]
22272299 ):
0 commit comments